diff --git a/.gitattributes b/.gitattributes
index f8d2f31cf3853000a83449a2a5c3ac5bdc2b219f..375c7e662659a9b787545e4aeb0ec894bba5e3ed 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -53,3 +53,5 @@ docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text
docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text
docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text
docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text
+docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text
+docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text
diff --git a/docs/resources/grpo_clevr_count.png b/docs/resources/grpo_clevr_count.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c814e28e31908a7250220fde18b5bd024d82b35
--- /dev/null
+++ b/docs/resources/grpo_clevr_count.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7192dc4f04801dbdff30bed098a16a7e21212a773ba7b6dc1424b261feca366f
+size 671176
diff --git a/docs/resources/grpo_countdown_1.png b/docs/resources/grpo_countdown_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..819ab3d992619b077d75e6946d4637b030b8d213
--- /dev/null
+++ b/docs/resources/grpo_countdown_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b78dc3ce1cd541e76f2c557dea3aff06b278bb3b5413946a92c584cf42c1369f
+size 785044
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e2dca8bcd01c853025ec4364c10b4250421cf2d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""EfficientFormer model configuration"""
+
+from typing import List
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class EfficientFormerConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of an [`EfficientFormerModel`]. It is used to
+ instantiate an EfficientFormer model according to the specified arguments, defining the model architecture.
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the EfficientFormer
+ [snap-research/efficientformer-l1](https://huggingface.co/snap-research/efficientformer-l1) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+        depths (`List[int]`, *optional*, defaults to `[3, 2, 6, 4]`):
+ Depth of each stage.
+        hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 224, 448]`):
+ Dimensionality of each stage.
+        downsamples (`List[bool]`, *optional*, defaults to `[True, True, True, True]`):
+ Whether or not to downsample inputs between two stages.
+ dim (`int`, *optional*, defaults to 448):
+ Number of channels in Meta3D layers
+ key_dim (`int`, *optional*, defaults to 32):
+ The size of the key in meta3D block.
+ attention_ratio (`int`, *optional*, defaults to 4):
+ Ratio of the dimension of the query and value to the dimension of the key in MSHA block
+        resolution (`int`, *optional*, defaults to 7):
+ Size of each patch
+ num_hidden_layers (`int`, *optional*, defaults to 5):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 8):
+ Number of attention heads for each attention layer in the 3D MetaBlock.
+ mlp_expansion_ratio (`int`, *optional*, defaults to 4):
+ Ratio of size of the hidden dimensionality of an MLP to the dimensionality of its input.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings and encoder.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ pool_size (`int`, *optional*, defaults to 3):
+ Kernel size of pooling layers.
+ downsample_patch_size (`int`, *optional*, defaults to 3):
+ The size of patches in downsampling layers.
+ downsample_stride (`int`, *optional*, defaults to 2):
+ The stride of convolution kernels in downsampling layers.
+ downsample_pad (`int`, *optional*, defaults to 1):
+ Padding in downsampling layers.
+        drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Rate at which to increase dropout probability in DropPath.
+ num_meta3d_blocks (`int`, *optional*, defaults to 1):
+ The number of 3D MetaBlocks in the last stage.
+ distillation (`bool`, *optional*, defaults to `True`):
+ Whether to add a distillation head.
+ use_layer_scale (`bool`, *optional*, defaults to `True`):
+ Whether to scale outputs from token mixers.
+ layer_scale_init_value (`float`, *optional*, defaults to 1e-5):
+ Factor by which outputs from token mixers are scaled.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to `224`):
+ The size (resolution) of each image.
+
+ Example:
+
+ ```python
+ >>> from transformers import EfficientFormerConfig, EfficientFormerModel
+
+ >>> # Initializing a EfficientFormer efficientformer-l1 style configuration
+ >>> configuration = EfficientFormerConfig()
+
+ >>> # Initializing a EfficientFormerModel (with random weights) from the efficientformer-l3 style configuration
+ >>> model = EfficientFormerModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "efficientformer"
+
+ def __init__(
+ self,
+ depths: List[int] = [3, 2, 6, 4],
+ hidden_sizes: List[int] = [48, 96, 224, 448],
+ downsamples: List[bool] = [True, True, True, True],
+ dim: int = 448,
+ key_dim: int = 32,
+ attention_ratio: int = 4,
+ resolution: int = 7,
+ num_hidden_layers: int = 5,
+ num_attention_heads: int = 8,
+ mlp_expansion_ratio: int = 4,
+ hidden_dropout_prob: float = 0.0,
+ patch_size: int = 16,
+ num_channels: int = 3,
+ pool_size: int = 3,
+ downsample_patch_size: int = 3,
+ downsample_stride: int = 2,
+ downsample_pad: int = 1,
+ drop_path_rate: float = 0.0,
+ num_meta3d_blocks: int = 1,
+ distillation: bool = True,
+ use_layer_scale: bool = True,
+ layer_scale_init_value: float = 1e-5,
+ hidden_act: str = "gelu",
+ initializer_range: float = 0.02,
+ layer_norm_eps: float = 1e-12,
+ image_size: int = 224,
+ batch_norm_eps: float = 1e-05,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.hidden_sizes = hidden_sizes
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.depths = depths
+ self.mlp_expansion_ratio = mlp_expansion_ratio
+ self.downsamples = downsamples
+ self.dim = dim
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.resolution = resolution
+ self.pool_size = pool_size
+ self.downsample_patch_size = downsample_patch_size
+ self.downsample_stride = downsample_stride
+ self.downsample_pad = downsample_pad
+ self.drop_path_rate = drop_path_rate
+ self.num_meta3d_blocks = num_meta3d_blocks
+ self.distillation = distillation
+ self.use_layer_scale = use_layer_scale
+ self.layer_scale_init_value = layer_scale_init_value
+ self.image_size = image_size
+ self.batch_norm_eps = batch_norm_eps
+
+
+__all__ = [
+ "EfficientFormerConfig",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d16a048de1e59a6d6e0f7fe001bb7194273abd
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py
@@ -0,0 +1,324 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for EfficientFormer."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ....image_transforms import (
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ....image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_batched,
+ is_scaled_image,
+ to_numpy_array,
+ valid_images,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ....utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class EfficientFormerImageProcessor(BaseImageProcessor):
+ r"""
+    Constructs an EfficientFormer image processor.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
+ method.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+ `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+ method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Optional[Dict[str, int]] = None,
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
+ do_center_crop: bool = True,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ crop_size: Dict[str, int] = None,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"height": 224, "width": 224}
+ size = get_size_dict(size)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.do_rescale = do_rescale
+ self.do_normalize = do_normalize
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.size = size
+ self.resample = resample
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+ self._valid_processor_keys = [
+ "images",
+ "do_resize",
+ "size",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image to `(size["height"], size["width"])`.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+ resample:
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
+ data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+
+ Returns:
+ `np.ndarray`: The resized image.
+ """
+ size = get_size_dict(size)
+
+ if "shortest_edge" in size:
+ size = get_resize_output_image_size(
+ image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
+ )
+ # size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"])
+ elif "height" in size and "width" in size:
+ size = (size["height"], size["width"])
+ else:
+ raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}")
+ return resize(
+ image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ def preprocess(
+ self,
+ images: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Optional[int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an image or batch of images.
+
+ Args:
+ images (`ImageInput`):
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
+ resizing.
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
+ an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+ Whether to center crop the image.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean to use if `do_normalize` is set to `True`.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation to use if `do_normalize` is set to `True`.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
+ resample = resample if resample is not None else self.resample
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size_dict = get_size_dict(size)
+
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+ if not is_batched(images):
+ images = [images]
+
+ if not valid_images(images):
+ raise ValueError(
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+ # All transformations expect numpy arrays.
+ images = [to_numpy_array(image) for image in images]
+
+ if do_rescale and is_scaled_image(images[0]):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ # We assume that all images have the same channel dimension format.
+ input_data_format = infer_channel_dimension_format(images[0])
+
+ if do_resize:
+ images = [
+ self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_center_crop:
+ images = [
+ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
+ ]
+
+ if do_rescale:
+ images = [
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ if do_normalize:
+ images = [
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ for image in images
+ ]
+
+ images = [
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["EfficientFormerImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a45fe7da5da3f3077b48b3aa23ca557cc60dd5e9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py
@@ -0,0 +1,807 @@
+# coding=utf-8
+# Copyright 2022 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch EfficientFormer model."""
+
+import itertools
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ....modeling_utils import PreTrainedModel
+from ....utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_efficientformer import EfficientFormerConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "EfficientFormerConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+class EfficientFormerPatchEmbeddings(nn.Module):
+ """
+ This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels,
+ height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride]
+ """
+
+ def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True):
+ super().__init__()
+ self.num_channels = num_channels
+
+ self.projection = nn.Conv2d(
+ num_channels,
+ embed_dim,
+ kernel_size=config.downsample_patch_size,
+ stride=config.downsample_stride,
+ padding=config.downsample_pad,
+ )
+ self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity()
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+
+ embeddings = self.projection(pixel_values)
+ embeddings = self.norm(embeddings)
+
+ return embeddings
+
+
+class EfficientFormerSelfAttention(nn.Module):
+ def __init__(self, dim: int, key_dim: int, num_heads: int, attention_ratio: int, resolution: int):
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.key_dim = key_dim
+ self.attention_ratio = attention_ratio
+ self.scale = key_dim**-0.5
+ self.total_key_dim = key_dim * num_heads
+ self.expanded_key_dim = int(attention_ratio * key_dim)
+ self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
+ hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2
+ self.qkv = nn.Linear(dim, hidden_size)
+ self.projection = nn.Linear(self.total_expanded_key_dim, dim)
+ points = list(itertools.product(range(resolution), range(resolution)))
+ num_points = len(points)
+ attention_offsets = {}
+ idxs = []
+ for point_1 in points:
+ for point_2 in points:
+ offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(num_points, num_points))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and hasattr(self, "ab"):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
+ batch_size, sequence_length, num_channels = hidden_states.shape
+ qkv = self.qkv(hidden_states)
+ query_layer, key_layer, value_layer = qkv.reshape(batch_size, sequence_length, self.num_heads, -1).split(
+ [self.key_dim, self.key_dim, self.expanded_key_dim], dim=3
+ )
+ query_layer = query_layer.permute(0, 2, 1, 3)
+ key_layer = key_layer.permute(0, 2, 1, 3)
+ value_layer = value_layer.permute(0, 2, 1, 3)
+
+ # set `model.to(torch_device)` won't change `self.ab.device`, if there is no follow-up `train` or `eval` call.
+ # Let's do it manually here, so users won't have to do this everytime.
+ if not self.training:
+ self.ab = self.ab.to(self.attention_biases.device)
+ attention_probs = (torch.matmul(query_layer, key_layer.transpose(-2, -1))) * self.scale + (
+ self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
+ )
+
+ attention_probs = attention_probs.softmax(dim=-1)
+
+ context_layer = torch.matmul(attention_probs, value_layer).transpose(1, 2)
+ context_layer = context_layer.reshape(batch_size, sequence_length, self.total_expanded_key_dim)
+ context_layer = self.projection(context_layer)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class EfficientFormerConvStem(nn.Module):
+ def __init__(self, config: EfficientFormerConfig, out_channels: int):
+ super().__init__()
+
+ self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1)
+ self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps)
+
+ self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1)
+ self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
+
+ self.activation = nn.ReLU()
+
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ features = self.batchnorm_before(self.convolution1(pixel_values))
+ features = self.activation(features)
+ features = self.batchnorm_after(self.convolution2(features))
+ features = self.activation(features)
+
+ return features
+
+
+class EfficientFormerPooling(nn.Module):
+ def __init__(self, pool_size: int):
+ super().__init__()
+ self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ output = self.pool(hidden_states) - hidden_states
+ return output
+
+
+class EfficientFormerDenseMlp(nn.Module):
+ def __init__(
+ self,
+ config: EfficientFormerConfig,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.linear_in = nn.Linear(in_features, hidden_features)
+ self.activation = ACT2FN[config.hidden_act]
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.linear_out = nn.Linear(hidden_features, out_features)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.linear_in(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.linear_out(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
class EfficientFormerConvMlp(nn.Module):
    """Two-layer 1x1-convolution MLP with batch normalization after each convolution."""

    def __init__(
        self,
        config: EfficientFormerConfig,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        # Fall back to the input width when hidden/output widths are not provided.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Submodules are registered in the same order as the reference implementation
        # so parameter traversal (and hence weight init) is unchanged.
        self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
        self.activation = ACT2FN[config.hidden_act]
        self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
        self.dropout = nn.Dropout(drop)

        self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
        self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # 1x1 conv -> BN -> activation -> dropout
        hidden_state = self.batchnorm_before(self.convolution1(hidden_state))
        hidden_state = self.dropout(self.activation(hidden_state))
        # 1x1 conv -> BN -> dropout
        hidden_state = self.batchnorm_after(self.convolution2(hidden_state))
        return self.dropout(hidden_state)
+
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # No-op at inference time or when the rate is zero: return the input unchanged.
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    mask.floor_()  # binarize: keep_prob + U[0,1) floored gives 0/1
    # Rescale kept samples by 1/keep_prob so the expected value is preserved.
    return input.div(keep_prob) * mask
+
+
class EfficientFormerDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional implementation; only active in training mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown in the module's repr, e.g. "EfficientFormerDropPath(p=0.1)".
        return f"p={self.drop_prob}"
+
+
class EfficientFormerFlat(nn.Module):
    """Flattens a (batch, channels, height, width) feature map into a (batch, height*width, channels) token sequence."""

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Fix: the return annotation previously claimed Tuple[torch.Tensor] but a
        # plain tensor is returned.
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states
+
+
class EfficientFormerMeta3D(nn.Module):
    """
    Meta3D block: self-attention token mixer followed by a dense MLP, each in a
    pre-LayerNorm residual branch with optional layer scale and stochastic depth.
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()

        # NOTE(review): the token mixer is sized from `config.dim`, not the `dim`
        # argument used by every other submodule of this block — presumably the two
        # coincide in the released configs; confirm before reusing with other widths.
        self.token_mixer = EfficientFormerSelfAttention(
            dim=config.dim,
            key_dim=config.key_dim,
            num_heads=config.num_attention_heads,
            attention_ratio=config.attention_ratio,
            resolution=config.resolution,
        )

        self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)

        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)

        # Stochastic depth on both residual branches; identity when the rate is zero.
        self.drop_path = EfficientFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            # Learnable per-channel scaling of each residual branch (layer scale).
            self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """Returns (layer_output,) optionally followed by the attention probabilities."""
        self_attention_outputs = self.token_mixer(self.layernorm1(hidden_states), output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        if self.use_layer_scale:
            # unsqueeze twice to broadcast the (dim,) scales over batch and sequence.
            layer_output = hidden_states + self.drop_path(
                self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attention_output
            )
            layer_output = layer_output + self.drop_path(
                self.layer_scale_2.unsqueeze(0).unsqueeze(0) * self.mlp(self.layernorm2(layer_output))
            )
        else:
            layer_output = hidden_states + self.drop_path(attention_output)
            layer_output = layer_output + self.drop_path(self.mlp(self.layernorm2(layer_output)))

        outputs = (layer_output,) + outputs

        return outputs
+
+
class EfficientFormerMeta3DLayers(nn.Module):
    """Stack of Meta3D attention blocks that terminate the last backbone stage."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        # Stochastic-depth rate scales with the global block index; the Meta3D blocks
        # come after all Meta4D blocks, hence the offset by the earlier stages' depths.
        drop_paths = [
            config.drop_path_rate * (block_idx + sum(config.depths[:-1]))
            for block_idx in range(config.num_meta3d_blocks)
        ]
        self.blocks = nn.ModuleList(
            [EfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path) for drop_path in drop_paths]
        )

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        """
        Returns (final_hidden_state, *per_block_attentions) when output_attentions is
        True; otherwise the last block's 1-tuple (hidden_state,) is returned as-is.
        """
        all_attention_outputs = () if output_attentions else None

        for layer_module in self.blocks:
            # Each Meta3D block returns a tuple; unwrap the hidden state before feeding
            # it to the next block.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

            hidden_states = layer_module(hidden_states, output_attentions)

            if output_attentions:
                # index 1 is the attention probabilities of the block just applied
                all_attention_outputs = all_attention_outputs + (hidden_states[1],)

        if output_attentions:
            outputs = (hidden_states[0],) + all_attention_outputs
            return outputs

        return hidden_states
+
+
class EfficientFormerMeta4D(nn.Module):
    """
    Meta4D block: pooling token mixer followed by a 1x1-conv MLP, each in a residual
    branch with optional layer scale and stochastic depth. Operates on (B, C, H, W).
    """

    def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0):
        super().__init__()
        # Default pool size of 3 when the config leaves it unset.
        pool_size = config.pool_size if config.pool_size is not None else 3
        self.token_mixer = EfficientFormerPooling(pool_size=pool_size)
        mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
        self.mlp = EfficientFormerConvMlp(
            config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob
        )

        # Stochastic depth on both residual branches; identity when the rate is zero.
        self.drop_path = EfficientFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = config.use_layer_scale
        if config.use_layer_scale:
            # Learnable per-channel scaling of each residual branch (layer scale).
            self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # token_mixer returns (avg_pool - identity); the residual add below restores
        # the identity part.
        outputs = self.token_mixer(hidden_states)

        if self.use_layer_scale:
            # unsqueeze(-1) twice broadcasts the (dim,) scales over H and W
            # (channels-first layout).
            layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs)

            layer_output = layer_output + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output)
            )
        else:
            layer_output = hidden_states + self.drop_path(outputs)
            layer_output = layer_output + self.drop_path(self.mlp(layer_output))

        return layer_output
+
+
class EfficientFormerMeta4DLayers(nn.Module):
    """Stack of Meta4D (pooling-mixer) blocks making up one backbone stage."""

    def __init__(self, config: EfficientFormerConfig, stage_idx: int):
        super().__init__()
        # The last stage (stage_idx == -1) reserves its final blocks for Meta3D layers.
        if stage_idx == -1:
            num_layers = config.depths[stage_idx] - config.num_meta3d_blocks
        else:
            num_layers = config.depths[stage_idx]
        # Stochastic-depth rate grows with the global block index across stages.
        drop_paths = [
            config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers)
        ]

        self.blocks = nn.ModuleList(
            [
                EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path)
                for drop_path in drop_paths
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        # Apply the blocks sequentially; each consumes and produces a (B, C, H, W) map.
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states
+
+
class EfficientFormerIntermediateStage(nn.Module):
    """One intermediate backbone stage: a stack of Meta4D layers at a fixed resolution."""

    def __init__(self, config: EfficientFormerConfig, index: int):
        super().__init__()
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, index)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        # Pure pass-through to the Meta4D stack.
        return self.meta4D_layers(hidden_states)
+
+
class EfficientFormerLastStage(nn.Module):
    """Final backbone stage: Meta4D layers, flatten to tokens, then Meta3D attention layers."""

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        self.meta4D_layers = EfficientFormerMeta4DLayers(config, -1)
        self.flat = EfficientFormerFlat()
        self.meta3D_layers = EfficientFormerMeta3DLayers(config)

    def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
        # Run the conv-style 4D blocks first, flatten (B, C, H, W) -> (B, H*W, C),
        # then hand the token sequence to the attention-based 3D blocks.
        hidden_states = self.meta4D_layers(hidden_states)
        tokens = self.flat(hidden_states)
        return self.meta3D_layers(tokens, output_attentions)
+
+
class EfficientFormerEncoder(nn.Module):
    """
    Full EfficientFormer backbone: alternating intermediate Meta4D stages and
    patch-embedding downsampling layers, followed by the last (Meta4D + Meta3D) stage.
    """

    def __init__(self, config: EfficientFormerConfig):
        super().__init__()
        self.config = config
        num_intermediate_stages = len(config.depths) - 1
        # Downsample between stages when the config requests it or when the hidden
        # size changes (a channel change requires a projection anyway).
        downsamples = [
            config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1]
            for i in range(num_intermediate_stages)
        ]
        intermediate_stages = []

        for i in range(num_intermediate_stages):
            intermediate_stages.append(EfficientFormerIntermediateStage(config, i))
            if downsamples[i]:
                intermediate_stages.append(
                    EfficientFormerPatchEmbeddings(config, config.hidden_sizes[i], config.hidden_sizes[i + 1])
                )

        self.intermediate_stages = nn.ModuleList(intermediate_stages)
        self.last_stage = EfficientFormerLastStage(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
        return_dict: bool = True,
    ) -> BaseModelOutput:
        """
        Runs the intermediate stages then the last stage, optionally collecting the
        per-stage hidden states and the last stage's attention maps.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # include the stem output as the first "hidden state"
            all_hidden_states = all_hidden_states + (hidden_states,)

        for layer_module in self.intermediate_stages:
            hidden_states = layer_module(hidden_states)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        # The last stage returns a tuple: (hidden_state, *attentions-if-requested).
        layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)

        if output_attentions:
            all_self_attentions = all_self_attentions + layer_output[1:]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (layer_output[0],)

        if not return_dict:
            # tuple form: drop the None entries
            return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=layer_output[0],
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
class EfficientFormerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientFormerConfig
    base_model_prefix = "efficientformer"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module: nn.Module):
        """Initialize the weights"""
        # Linear/conv: normal(0, initializer_range) weights, zero biases.
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        # LayerNorm: identity transform (unit weight, zero bias).
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
# Shared docstring fragments consumed by the add_start_docstrings /
# add_start_docstrings_to_model_forward decorators on the model classes below.
EFFICIENTFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) subclass. Use it as a
    regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

    Parameters:
        config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`ViTImageProcessor`]. See
            [`ViTImageProcessor.preprocess`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerModel(EfficientFormerPreTrainedModel):
    # Fix: this was previously assigned to a throwaway local variable inside
    # __init__, which had no effect. It must be a class attribute for the model
    # parallelism / device_map="auto" machinery to see it.
    _no_split_modules = ["EfficientFormerMeta4D"]

    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)
        self.config = config

        self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
        self.encoder = EfficientFormerEncoder(config)
        # Final norm over the last stage's hidden size.
        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Embed pixels with the conv stem, run the encoder, and layer-normalize the
        final hidden states.

        NOTE(review): the code-sample decorator advertises BaseModelOutputWithPooling,
        but this model returns a BaseModelOutput (no pooler) — doc-sample mismatch only.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.patch_embed(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output, output_attentions=output_attentions, output_hidden_states=output_hidden_states
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with an image classification head on top (a linear layer on top of the final
    hidden state of the [CLS] token) e.g. for ImageNet.
    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassification(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Mean-pool over the token dimension (dim -2) before classification.
        logits = self.classifier(sequence_output.mean(-2))

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype,
            # then cache it on the config (standard HF classification-head pattern).
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # tuple form: (loss?, logits, *extra encoder outputs)
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@dataclass
class EfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`EfficientFormerForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # All fields are optional so the output can be built from partial results
    # (e.g. when hidden states / attentions are not requested).
    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
@add_start_docstrings(
    """
    EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
    state of the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for
    ImageNet.



    This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
    supported.


    """,
    EFFICIENTFORMER_START_DOCSTRING,
)
class EfficientFormerForImageClassificationWithTeacher(EfficientFormerPreTrainedModel):
    def __init__(self, config: EfficientFormerConfig):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.efficientformer = EfficientFormerModel(config)

        # Classifier head
        # NOTE(review): these heads are sized from `config.hidden_size`, whereas
        # EfficientFormerForImageClassification uses `config.hidden_sizes[-1]` —
        # presumably equal in the released configs; confirm before changing configs.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        # Distillation head
        self.distillation_classifier = (
            nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=EfficientFormerForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, EfficientFormerForImageClassificationWithTeacherOutput]:
        """Inference-only forward: both heads read the mean-pooled final hidden state."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.efficientformer(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Mean-pool over the token dimension (dim -2) for both heads.
        cls_logits = self.classifier(sequence_output.mean(-2))
        distillation_logits = self.distillation_classifier(sequence_output.mean(-2))

        # during inference, return the average of both classifier predictions
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return EfficientFormerForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
# Public API of this module, re-exported by the package's __init__.
__all__ = [
    "EfficientFormerForImageClassification",
    "EfficientFormerForImageClassificationWithTeacher",
    "EfficientFormerModel",
    "EfficientFormerPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e11fa1edf9d08f6ae773a025a9f3fbf7039cb1c9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py
@@ -0,0 +1,1198 @@
+# coding=utf-8
+# Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TensorFlow EfficientFormer model."""
+
+import itertools
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ....activations_tf import ACT2FN
+from ....modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFBaseModelOutputWithPooling,
+ TFImageClassifierOutput,
+)
+from ....modeling_tf_utils import (
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ....tf_utils import shape_list, stable_softmax
+from ....utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_efficientformer import EfficientFormerConfig
+
+
logger = logging.get_logger(__name__)

# Constants consumed by the docstring decorators on the TF model classes below.

# General docstring
_CONFIG_FOR_DOC = "EfficientFormerConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
# [batch, tokens, hidden] shape expected of the base model output in the doc example
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281"
+
+
class TFEfficientFormerPatchEmbeddings(keras.layers.Layer):
    """
    This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels,
    height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride]
    """

    def __init__(
        self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.num_channels = num_channels

        # Explicit zero-padding + "valid" conv reproduces the PyTorch padded convolution.
        self.padding = keras.layers.ZeroPadding2D(padding=config.downsample_pad)
        self.projection = keras.layers.Conv2D(
            filters=embed_dim,
            kernel_size=config.downsample_patch_size,
            strides=config.downsample_stride,
            padding="valid",
            name="projection",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.norm = (
            keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
            if apply_norm
            else tf.identity
        )
        self.embed_dim = embed_dim

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        # Channels-last layout: the last axis must match the configured channel count.
        tf.debugging.assert_shapes(
            [(pixel_values, (..., None, None, self.num_channels))],
            message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
        )
        embeddings = self.projection(self.padding(pixel_values))
        # NOTE(review): when apply_norm is False, self.norm is tf.identity, which does
        # not accept a `training` kwarg — presumably that path is never exercised; verify.
        embeddings = self.norm(embeddings, training=training)
        return embeddings

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
        if getattr(self, "norm", None) is not None:
            # tf.identity has no `name`/`build`; only build the real BatchNormalization layer.
            if hasattr(self.norm, "name"):
                with tf.name_scope(self.norm.name):
                    self.norm.build([None, None, None, self.embed_dim])
+
+
class TFEfficientFormerSelfAttention(keras.layers.Layer):
    """
    Multi-head self-attention with learned relative-position biases, sized for a fixed
    square token grid of side `resolution`. Queries/keys use `key_dim` per head while
    values use an `attention_ratio`-expanded width.
    """

    def __init__(
        self,
        dim: int,
        key_dim: int,
        num_heads: int,
        attention_ratio: int,
        resolution: int,
        config: EfficientFormerConfig,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.scale = key_dim**-0.5
        self.total_key_dim = key_dim * num_heads
        # Value channels per head are expanded by attention_ratio.
        self.expanded_key_dim = int(attention_ratio * key_dim)
        self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads)
        # One fused projection produces q, k and v: 2x key width + expanded value width.
        hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2

        self.qkv = keras.layers.Dense(
            units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv"
        )
        self.projection = keras.layers.Dense(
            units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection"
        )
        self.resolution = resolution
        self.dim = dim

    def build(self, input_shape: tf.TensorShape) -> None:
        # Enumerate all pairwise |dx|,|dy| offsets on the resolution x resolution grid;
        # token pairs with identical offsets share one learned bias per head.
        points = list(itertools.product(range(self.resolution), range(self.resolution)))
        num_points = len(points)
        attention_offsets = {}

        idxs = []

        for point_1 in points:
            for point_2 in points:
                offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])

        # Trainable bias table: one entry per head per distinct offset.
        self.attention_biases = self.add_weight(
            shape=(self.num_heads, len(attention_offsets)),
            initializer=keras.initializers.zeros(),
            trainable=True,
            name="attention_biases",
        )
        # Non-trainable index map from (query token, key token) to bias-table entry.
        self.attention_bias_idxs = self.add_weight(
            shape=(num_points, num_points),
            trainable=False,
            dtype=tf.int32,
            name="attention_bias_idxs",
        )

        self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points)))

        if self.built:
            return
        self.built = True
        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.dim])
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, self.total_expanded_key_dim])

    def call(
        self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
    ) -> Tuple[tf.Tensor]:
        # `training` is accepted for layer-interface uniformity but unused here
        # (this attention applies no dropout).
        batch_size, sequence_length, *_ = shape_list(hidden_states)
        qkv = self.qkv(inputs=hidden_states)

        # Split the fused projection into per-head q (key_dim), k (key_dim) and
        # v (expanded_key_dim) along the last axis.
        query_layer, key_layer, value_layer = tf.split(
            tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)),
            num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim],
            axis=3,
        )

        # (batch, heads, tokens, channels)
        query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3])
        key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3])
        value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3])

        attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2]))
        scale = tf.cast(self.scale, dtype=attention_probs.dtype)
        attention_probs = tf.multiply(attention_probs, scale)

        # Add the per-head relative position biases before the softmax.
        attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1)
        attention_probs = attention_probs + attention_biases
        attention_probs = stable_softmax(logits=attention_probs, axis=-1)

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])

        # Merge heads back into a single channel dimension, then project to `dim`.
        context_layer = tf.reshape(
            tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim)
        )
        context_layer = self.projection(context_layer)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
class TFEfficientFormerConvStem(keras.layers.Layer):
    """
    Convolutional stem: two stride-2 3x3 convolutions (each followed by batch norm and
    ReLU) that reduce the spatial size by 4x and project to `out_channels` channels.
    """

    def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs):
        super().__init__(**kwargs)

        # Explicit zero-padding + "valid" conv reproduces the PyTorch padding=1 conv.
        self.padding = keras.layers.ZeroPadding2D(padding=1)
        self.convolution1 = keras.layers.Conv2D(
            filters=out_channels // 2, kernel_size=3, strides=2, padding="valid", name="convolution1"
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_before = keras.layers.BatchNormalization(
            axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before"
        )

        self.convolution2 = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=3,
            strides=2,
            padding="valid",
            name="convolution2",
        )
        # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
        self.batchnorm_after = keras.layers.BatchNormalization(
            axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after"
        )

        self.activation = keras.layers.Activation(activation=keras.activations.relu, name="activation")
        self.out_channels = out_channels
        self.config = config

    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
        # pad -> conv (stride 2) -> BN -> ReLU, twice.
        features = self.batchnorm_before(self.convolution1(self.padding(pixel_values)), training=training)
        features = self.activation(features)
        features = self.batchnorm_after(self.convolution2(self.padding(features)), training=training)
        features = self.activation(features)
        return features

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build sublayers with explicit channels-last shapes so weight names/scopes
        # match the released checkpoints.
        if getattr(self, "convolution1", None) is not None:
            with tf.name_scope(self.convolution1.name):
                self.convolution1.build([None, None, None, self.config.num_channels])
        if getattr(self, "batchnorm_before", None) is not None:
            with tf.name_scope(self.batchnorm_before.name):
                self.batchnorm_before.build([None, None, None, self.out_channels // 2])
        if getattr(self, "convolution2", None) is not None:
            with tf.name_scope(self.convolution2.name):
                self.convolution2.build([None, None, None, self.out_channels // 2])
        if getattr(self, "batchnorm_after", None) is not None:
            with tf.name_scope(self.batchnorm_after.name):
                self.batchnorm_after.build([None, None, None, self.out_channels])
        if getattr(self, "activation", None) is not None:
            with tf.name_scope(self.activation.name):
                self.activation.build(None)
+
+
class TFEfficientFormerPooling(keras.layers.Layer):
    """Pooling token mixer: average pooling minus identity, mirroring the PyTorch version."""

    def __init__(self, pool_size: int, **kwargs):
        super().__init__(**kwargs)
        # Stride 1 with "same" padding keeps the spatial dimensions unchanged.
        self.pool = keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # Subtract the input so the surrounding residual adds only the pooled delta.
        return self.pool(hidden_states) - hidden_states
+
+
+class TFEfficientFormerDenseMlp(keras.layers.Layer):
+ def __init__(
+ self,
+ config: EfficientFormerConfig,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.linear_in = keras.layers.Dense(
+ units=hidden_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_in"
+ )
+ self.activation = ACT2FN[config.hidden_act]
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
+
+ self.linear_out = keras.layers.Dense(
+ units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_out"
+ )
+ self.hidden_features = hidden_features
+ self.in_features = in_features
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_states = self.linear_in(inputs=hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+ hidden_states = self.linear_out(inputs=hidden_states)
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "linear_in", None) is not None:
+ with tf.name_scope(self.linear_in.name):
+ self.linear_in.build([None, None, self.in_features])
+ if getattr(self, "linear_out", None) is not None:
+ with tf.name_scope(self.linear_out.name):
+ self.linear_out.build([None, None, self.hidden_features])
+
+
+class TFEfficientFormerConvMlp(keras.layers.Layer):
+ def __init__(
+ self,
+ config: EfficientFormerConfig,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ drop: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.convolution1 = keras.layers.Conv2D(
+ filters=hidden_features,
+ kernel_size=1,
+ name="convolution1",
+ padding="valid",
+ )
+
+ self.activation = ACT2FN[config.hidden_act]
+
+ self.convolution2 = keras.layers.Conv2D(
+ filters=out_features,
+ kernel_size=1,
+ name="convolution2",
+ padding="valid",
+ )
+
+ self.dropout = keras.layers.Dropout(rate=drop)
+
+ # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
+ self.batchnorm_before = keras.layers.BatchNormalization(
+ axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before"
+ )
+ # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization
+ self.batchnorm_after = keras.layers.BatchNormalization(
+ axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after"
+ )
+ self.hidden_features = hidden_features
+ self.in_features = in_features
+ self.out_features = out_features
+
+ def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
+ hidden_state = self.convolution1(hidden_state)
+ hidden_state = self.batchnorm_before(hidden_state, training=training)
+ hidden_state = self.activation(hidden_state)
+ hidden_state = self.dropout(hidden_state, training=training)
+ hidden_state = self.convolution2(hidden_state)
+ hidden_state = self.batchnorm_after(hidden_state, training=training)
+ hidden_state = self.dropout(hidden_state, training=training)
+ return hidden_state
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "convolution1", None) is not None:
+ with tf.name_scope(self.convolution1.name):
+ self.convolution1.build([None, None, None, self.in_features])
+ if getattr(self, "convolution2", None) is not None:
+ with tf.name_scope(self.convolution2.name):
+ self.convolution2.build([None, None, None, self.hidden_features])
+ if getattr(self, "batchnorm_before", None) is not None:
+ with tf.name_scope(self.batchnorm_before.name):
+ self.batchnorm_before.build([None, None, None, self.hidden_features])
+ if getattr(self, "batchnorm_after", None) is not None:
+ with tf.name_scope(self.batchnorm_after.name):
+ self.batchnorm_after.build([None, None, None, self.out_features])
+
+
+# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer
+class TFEfficientFormerDropPath(keras.layers.Layer):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ References:
+ (1) github.com/rwightman/pytorch-image-models
+ """
+
+ def __init__(self, drop_path: float, **kwargs):
+ super().__init__(**kwargs)
+ self.drop_path = drop_path
+
+ def call(self, x: tf.Tensor, training=None):
+ if training:
+ keep_prob = 1 - self.drop_path
+ shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
+ random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
+ random_tensor = tf.floor(random_tensor)
+ return (x / keep_prob) * random_tensor
+ return x
+
+
+class TFEfficientFormerFlat(keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def call(self, hidden_states: tf.Tensor) -> Tuple[tf.Tensor]:
+ batch_size, _, _, in_channels = shape_list(hidden_states)
+ hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels])
+ return hidden_states
+
+
+class TFEfficientFormerMeta3D(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
+ super().__init__(**kwargs)
+
+ self.token_mixer = TFEfficientFormerSelfAttention(
+ dim=config.dim,
+ key_dim=config.key_dim,
+ num_heads=config.num_attention_heads,
+ attention_ratio=config.attention_ratio,
+ resolution=config.resolution,
+ name="token_mixer",
+ config=config,
+ )
+ self.dim = dim
+ self.config = config
+
+ self.layernorm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1")
+ self.layernorm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2")
+ mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
+ self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp")
+
+ # Using `layers.Activation` instead of `tf.identity` to better control `training` behavior.
+ self.drop_path = (
+ TFEfficientFormerDropPath(drop_path)
+ if drop_path > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+ self.config = config
+
+ def build(self, input_shape=None):
+ self.layer_scale_1 = None
+ self.layer_scale_2 = None
+
+ if self.config.use_layer_scale:
+ self.layer_scale_1 = self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_1",
+ )
+ self.layer_scale_2 = self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_2",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "token_mixer", None) is not None:
+ with tf.name_scope(self.token_mixer.name):
+ self.token_mixer.build(None)
+ if getattr(self, "layernorm1", None) is not None:
+ with tf.name_scope(self.layernorm1.name):
+ self.layernorm1.build([None, None, self.dim])
+ if getattr(self, "layernorm2", None) is not None:
+ with tf.name_scope(self.layernorm2.name):
+ self.layernorm2.build([None, None, self.dim])
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+ def call(
+ self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
+ ) -> Tuple[tf.Tensor]:
+ self_attention_outputs = self.token_mixer(
+ hidden_states=self.layernorm1(hidden_states, training=training),
+ output_attentions=output_attentions,
+ training=training,
+ )
+
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ if self.config.use_layer_scale:
+ layer_output = hidden_states + self.drop_path(
+ tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output,
+ training=training,
+ )
+ layer_output = layer_output + self.drop_path(
+ tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
+ * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
+ training=training,
+ )
+ else:
+ layer_output = hidden_states + self.drop_path(attention_output, training=training)
+ layer_output = layer_output + self.drop_path(
+ self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training),
+ training=training,
+ )
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class TFEfficientFormerMeta3DLayers(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, **kwargs):
+ super().__init__(**kwargs)
+ drop_paths = [
+ config.drop_path_rate * (block_idx + sum(config.depths[:-1]))
+ for block_idx in range(config.num_meta3d_blocks)
+ ]
+ self.blocks = [
+ TFEfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path, name=f"blocks.{i}")
+ for i, drop_path in enumerate(drop_paths)
+ ]
+
+ def call(
+ self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
+ ) -> Tuple[tf.Tensor]:
+ all_attention_outputs = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.blocks):
+ if isinstance(hidden_states, tuple):
+ hidden_states = hidden_states[0]
+
+ hidden_states = layer_module(
+ hidden_states=hidden_states, output_attentions=output_attentions, training=training
+ )
+ if output_attentions:
+ all_attention_outputs = all_attention_outputs + (hidden_states[1],)
+
+ if output_attentions:
+ outputs = (hidden_states[0],) + all_attention_outputs
+ return outputs
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "blocks", None) is not None:
+ for layer in self.blocks:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFEfficientFormerMeta4D(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
+ super().__init__(**kwargs)
+ pool_size = config.pool_size if config.pool_size is not None else 3
+ self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer")
+ self.dim = dim
+ mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
+ self.mlp = TFEfficientFormerConvMlp(
+ config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp"
+ )
+
+ self.drop_path = (
+ TFEfficientFormerDropPath(drop_path, name="drop_path")
+ if drop_path > 0.0
+ else keras.layers.Activation("linear", name="drop_path")
+ )
+ self.config = config
+
+ def build(self, input_shape=None):
+ self.layer_scale_1 = None
+ self.layer_scale_2 = None
+
+ if self.config.use_layer_scale:
+ self.layer_scale_1 = self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_1",
+ )
+ self.layer_scale_2 = self.add_weight(
+ shape=(self.dim,),
+ initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value),
+ trainable=True,
+ name="layer_scale_2",
+ )
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "token_mixer", None) is not None:
+ with tf.name_scope(self.token_mixer.name):
+ self.token_mixer.build(None)
+ if getattr(self, "mlp", None) is not None:
+ with tf.name_scope(self.mlp.name):
+ self.mlp.build(None)
+ if getattr(self, "drop_path", None) is not None:
+ with tf.name_scope(self.drop_path.name):
+ self.drop_path.build(None)
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
+ outputs = self.token_mixer(hidden_states)
+
+ if self.config.use_layer_scale:
+ layer_output = hidden_states + self.drop_path(
+ tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs,
+ training=training,
+ )
+
+ layer_output = layer_output + self.drop_path(
+ tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0)
+ * self.mlp(hidden_state=layer_output, training=training),
+ training=training,
+ )
+
+ else:
+ layer_output = hidden_states + self.drop_path(outputs, training=training)
+ layer_output = layer_output + self.drop_path(
+ self.mlp(hidden_state=layer_output, training=training), training=training
+ )
+
+ return layer_output
+
+
+class TFEfficientFormerMeta4DLayers(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs):
+ super().__init__(**kwargs)
+ num_layers = (
+ config.depths[stage_idx] if stage_idx != -1 else config.depths[stage_idx] - config.num_meta3d_blocks
+ )
+ drop_paths = [
+ config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers)
+ ]
+
+ self.blocks = [
+ TFEfficientFormerMeta4D(
+ config=config, dim=config.hidden_sizes[stage_idx], drop_path=drop_paths[i], name=f"blocks.{i}"
+ )
+ for i in range(len(drop_paths))
+ ]
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
+ for layer_module in self.blocks:
+ hidden_states = layer_module(hidden_states=hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "blocks", None) is not None:
+ for layer in self.blocks:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+class TFEfficientFormerIntermediateStage(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, index: int, **kwargs):
+ super().__init__(**kwargs)
+ self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers")
+
+ def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]:
+ hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training)
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "meta4D_layers", None) is not None:
+ with tf.name_scope(self.meta4D_layers.name):
+ self.meta4D_layers.build(None)
+
+
+class TFEfficientFormerLastStage(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, **kwargs):
+ super().__init__(**kwargs)
+ self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers")
+ self.flat = TFEfficientFormerFlat(name="flat")
+ self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers")
+
+ def call(
+ self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False
+ ) -> Tuple[tf.Tensor]:
+ hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training)
+ hidden_states = self.flat(hidden_states=hidden_states)
+ hidden_states = self.meta3D_layers(
+ hidden_states=hidden_states, output_attentions=output_attentions, training=training
+ )
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "meta4D_layers", None) is not None:
+ with tf.name_scope(self.meta4D_layers.name):
+ self.meta4D_layers.build(None)
+ if getattr(self, "flat", None) is not None:
+ with tf.name_scope(self.flat.name):
+ self.flat.build(None)
+ if getattr(self, "meta3D_layers", None) is not None:
+ with tf.name_scope(self.meta3D_layers.name):
+ self.meta3D_layers.build(None)
+
+
+class TFEfficientFormerEncoder(keras.layers.Layer):
+ def __init__(self, config: EfficientFormerConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ num_intermediate_stages = len(config.depths) - 1
+ downsamples = [
+ config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1]
+ for i in range(num_intermediate_stages)
+ ]
+
+ intermediate_stages = []
+ layer_count = -1
+ for i in range(num_intermediate_stages):
+ layer_count += 1
+ intermediate_stages.append(
+ TFEfficientFormerIntermediateStage(config, i, name=f"intermediate_stages.{layer_count}")
+ )
+ if downsamples[i]:
+ layer_count += 1
+ intermediate_stages.append(
+ TFEfficientFormerPatchEmbeddings(
+ config,
+ config.hidden_sizes[i],
+ config.hidden_sizes[i + 1],
+ name=f"intermediate_stages.{layer_count}",
+ )
+ )
+ self.intermediate_stages = intermediate_stages
+ self.last_stage = TFEfficientFormerLastStage(config, name="last_stage")
+
+ def call(
+ self,
+ hidden_states: tf.Tensor,
+ output_hidden_states: bool,
+ output_attentions: bool,
+ return_dict: bool,
+ training: bool = False,
+ ) -> TFBaseModelOutput:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ for layer_module in self.intermediate_stages:
+ hidden_states = layer_module(hidden_states, training=training)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_output = self.last_stage(hidden_states, output_attentions=output_attentions, training=training)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + layer_output[1:]
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (layer_output[0],)
+
+ if not return_dict:
+ return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None)
+
+ return TFBaseModelOutput(
+ last_hidden_state=layer_output[0],
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "last_stage", None) is not None:
+ with tf.name_scope(self.last_stage.name):
+ self.last_stage.build(None)
+ for layer in self.intermediate_stages:
+ with tf.name_scope(layer.name):
+ layer.build(None)
+
+
+@keras_serializable
+class TFEfficientFormerMainLayer(keras.layers.Layer):
+ config_class = EfficientFormerConfig
+
+ def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.config = config
+
+ self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed")
+ self.encoder = TFEfficientFormerEncoder(config, name="encoder")
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
+
+ @unpack_inputs
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ output_attentions: Optional[tf.Tensor] = None,
+ output_hidden_states: Optional[tf.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor, ...]]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # When running on CPU, keras.layers.Conv2D and keras.layers.AveragePool2D do not
+ # support channels first NCHW format. A number of blocks contain both.
+ # So change the input format from (batch_size, num_channels, height, width) to
+ # (batch_size, height, width, num_channels) here.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+ embedding_output = self.patch_embed(pixel_values, training=training)
+
+ encoder_outputs = self.encoder(
+ hidden_states=embedding_output,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output, training=training)
+
+ # Change the hidden states from (batch_size, height, width, num_channels) to
+ # (batch_size, num_channels, height, width).
+ # The hidden states are in (batch_size, height, width, num_channels)
+ # shape after all stages except the MB3D blocks.
+ if output_hidden_states:
+ hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1][:-1]]) + (
+ encoder_outputs[1][-1],
+ )
+
+ if not return_dict:
+ head_outputs = (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return TFBaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embed", None) is not None:
+ with tf.name_scope(self.patch_embed.name):
+ self.patch_embed.build(None)
+ if getattr(self, "encoder", None) is not None:
+ with tf.name_scope(self.encoder.name):
+ self.encoder.build(None)
+ if getattr(self, "layernorm", None) is not None:
+ with tf.name_scope(self.layernorm.name):
+ self.layernorm.build([None, None, self.config.hidden_sizes[-1]])
+
+
+class TFEfficientFormerPreTrainedModel(TFPreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = EfficientFormerConfig
+ base_model_prefix = "efficientformer"
+ main_input_name = "pixel_values"
+
+
+EFFICIENTFORMER_START_DOCSTRING = r"""
+ This model is a TensorFlow
+ [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular
+ TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.
+
+
+ Parameters:
+ config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+EFFICIENTFORMER_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`EfficientFormerImageProcessor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.",
+ EFFICIENTFORMER_START_DOCSTRING,
+)
+class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel):
+ def __init__(self, config: EfficientFormerConfig, **kwargs) -> None:
+ super().__init__(config, **kwargs)
+
+ self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFBaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[Tuple, TFBaseModelOutput]:
+ outputs = self.efficientformer(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+ return outputs
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "efficientformer", None) is not None:
+ with tf.name_scope(self.efficientformer.name):
+ self.efficientformer.build(None)
+
+
+@add_start_docstrings(
+ """
+ EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for
+ ImageNet.
+ """,
+ EFFICIENTFORMER_START_DOCSTRING,
+)
+class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss):
+ def __init__(self, config: EfficientFormerConfig):
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")
+
+ # Classifier head
+ self.classifier = (
+ keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="classifier")
+ )
+ self.config = config
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ labels: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[tf.Tensor, TFImageClassifierOutput]:
+ r"""
+ labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.efficientformer(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2))
+
+ loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TFImageClassifierOutput(
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "efficientformer", None) is not None:
+ with tf.name_scope(self.efficientformer.name):
+ self.efficientformer.build(None)
+ if getattr(self, "classifier", None) is not None:
+ if hasattr(self.classifier, "name"):
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+@dataclass
+class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput):
+ """
+ Output type of [`TFEfficientFormerForImageClassificationWithTeacher`].
+ Args:
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores as the average of the cls_logits and distillation logits.
+ cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
+ class token).
+ distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+ Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
+ distillation token).
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
+ `config.output_hidden_states=True`):
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+ `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
+ the initial embedding outputs.
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
+ `config.output_attentions=True`):
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+ the self-attention heads.
+ """
+
+ logits: Optional[tf.Tensor] = None
+ cls_logits: Optional[tf.Tensor] = None
+ distillation_logits: Optional[tf.Tensor] = None
+ hidden_states: Optional[Tuple[tf.Tensor]] = None
+ attentions: Optional[Tuple[tf.Tensor]] = None
+
+
+@add_start_docstrings(
+ """
+ EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden
+ state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.
+
+ .. warning::
+ This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
+ supported.
+ """,
+ EFFICIENTFORMER_START_DOCSTRING,
+)
+class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel):
+ def __init__(self, config: EfficientFormerConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer")
+
+ # Classifier heads
+ self.classifier = (
+ keras.layers.Dense(config.num_labels, name="classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="classifier")
+ )
+ self.distillation_classifier = (
+ keras.layers.Dense(config.num_labels, name="distillation_classifier")
+ if config.num_labels > 0
+ else keras.layers.Activation("linear", name="distillation_classifier")
+ )
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=TFEfficientFormerForImageClassificationWithTeacherOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def call(
+ self,
+ pixel_values: Optional[tf.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ training: bool = False,
+ ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if training:
+ raise Exception(
+ "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported."
+ )
+
+ outputs = self.efficientformer(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ sequence_output = outputs[0]
+
+ cls_logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2))
+ distillation_logits = self.distillation_classifier(tf.reduce_mean(sequence_output, axis=-2))
+ logits = (cls_logits + distillation_logits) / 2
+
+ if not return_dict:
+ output = (logits, cls_logits, distillation_logits) + outputs[1:]
+ return output
+
+ return TFEfficientFormerForImageClassificationWithTeacherOutput(
+ logits=logits,
+ cls_logits=cls_logits,
+ distillation_logits=distillation_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "efficientformer", None) is not None:
+ with tf.name_scope(self.efficientformer.name):
+ self.efficientformer.build(None)
+ if getattr(self, "classifier", None) is not None:
+ if hasattr(self.classifier, "name"):
+ with tf.name_scope(self.classifier.name):
+ self.classifier.build([None, None, self.config.hidden_sizes[-1]])
+ if getattr(self, "distillation_classifier", None) is not None:
+ if hasattr(self.distillation_classifier, "name"):
+ with tf.name_scope(self.distillation_classifier.name):
+ self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]])
+
+
+__all__ = [
+ "TFEfficientFormerForImageClassification",
+ "TFEfficientFormerForImageClassificationWithTeacher",
+ "TFEfficientFormerModel",
+ "TFEfficientFormerPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2beb8f463ff10af33ea980599ea4d5fc05888acb
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The HuggingFace and Baidu Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers / IDEs: expose the real symbols eagerly.
    from .configuration_ernie_m import *
    from .modeling_ernie_m import *
    from .tokenization_ernie_m import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy so the heavy
    # submodules are only imported when one of their attributes is accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a45106131850a82cb169fb01cecf96ff5f3e1a2
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py
@@ -0,0 +1,114 @@
+# coding=utf-8
+# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ErnieM model configuration"""
+# Adapted from original paddlenlp repository.(https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/ernie_m/configuration.py)
+
+from __future__ import annotations
+
+from typing import Dict
+
+from ....configuration_utils import PretrainedConfig
+
+
class ErnieMConfig(PretrainedConfig):
    r"""
    Stores the configuration of an [`ErnieMModel`] and is used to instantiate an Ernie-M model according to the
    specified arguments, defining the model architecture. Instantiating a configuration with the defaults yields a
    configuration similar to that of the `Ernie-M`
    [susnato/ernie-m-base_pytorch](https://huggingface.co/susnato/ernie-m-base_pytorch) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 250002):
            Vocabulary size, i.e. the number of distinct tokens representable by the `inputs_ids` passed to
            [`ErnieMModel`]; also the size of the token embedding matrix.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the embedding layer, the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward layer in the encoder: inputs are projected from `hidden_size` up to
            `intermediate_size` and back down to `hidden_size`, so it is typically the larger of the two.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            Non-linear activation of the feed-forward layer. `"gelu"`, `"relu"` and any other torch-supported
            activation function are accepted.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability applied to the attention weights inside `MultiHeadAttention` in every encoder layer.
        max_position_embeddings (`int`, *optional*, defaults to 514):
            Maximum size of the position encoding, which dictates the maximum supported input sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the normal initializer used for all weight matrices.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            Epsilon used by the layer normalization layers.
        classifier_dropout (`float`, *optional*):
            Dropout ratio for the classification head (falls back to `hidden_dropout_prob` when unset).
        act_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied after the activation inside `ErnieMEncoderLayer`.

    A normal_initializer initializes weight matrices as normal distributions. See
    `ErnieMPretrainedModel._init_weights()` for how weights are initialized in `ErnieMModel`.
    """

    model_type = "ernie_m"
    # Paddlenlp-style aliases: `config.dropout` and `config.num_classes` resolve
    # to `classifier_dropout` and `num_labels` respectively.
    attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"}

    def __init__(
        self,
        vocab_size: int = 250002,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 514,
        initializer_range: float = 0.02,
        pad_token_id: int = 1,
        layer_norm_eps: float = 1e-05,
        classifier_dropout=None,
        act_dropout=0.0,
        **kwargs,
    ):
        # `pad_token_id` is consumed by the base class; everything else is a
        # plain attribute on this config.
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout = classifier_dropout
        self.act_dropout = act_dropout
+
+
# Public API of this module.
__all__ = ["ErnieMConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..28c17afa3f7a2ca10d340794ac0b0a6b28e27915
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py
@@ -0,0 +1,1058 @@
+# coding=utf-8
+# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ErnieM model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn, tensor
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_ernie_m import ErnieMConfig
+
+
logger = logging.get_logger(__name__)

# Names injected into the auto-generated docstrings by the
# `add_code_sample_docstrings` decorators below.
_CHECKPOINT_FOR_DOC = "susnato/ernie-m-base_pytorch"
_CONFIG_FOR_DOC = "ErnieMConfig"
_TOKENIZER_FOR_DOC = "ErnieMTokenizer"
+
+
# Adapted from paddlenlp.transformers.ernie_m.modeling.ErnieEmbeddings
class ErnieMEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.padding_idx = config.pad_token_id

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed tokens and positions, then apply LayerNorm and dropout.

        Args:
            input_ids: Token ids; ignored when `inputs_embeds` is given.
            position_ids: Optional explicit positions; when omitted, positions
                0..seq_len-1 are generated (shifted by `past_key_values_length`).
            inputs_embeds: Precomputed token embeddings.
            past_key_values_length: Length of any cached prefix, used to offset
                the generated positions.

        Returns:
            The combined embedding tensor of shape `(*input_shape, hidden_size)`.
        """
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        if position_ids is None:
            input_shape = inputs_embeds.size()[:-1]
            ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device)
            seq_length = torch.cumsum(ones, dim=1)
            position_ids = seq_length - ones

        if past_key_values_length > 0:
            position_ids = position_ids + past_key_values_length
        # Offset by 2 to mimic the paddlenlp implementation. Out-of-place add:
        # the original `position_ids += 2` mutated a caller-supplied
        # `position_ids` tensor in place.
        position_ids = position_ids + 2
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
+
+
class ErnieMSelfAttention(nn.Module):
    """Multi-head self-attention with separate q/k/v projections.

    Supports cross-attention inputs, cached key/value states for decoder use,
    and optional relative position embeddings ("relative_key" /
    "relative_key_query").
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.q_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.k_proj = nn.Linear(config.hidden_size, self.all_head_size)
        self.v_proj = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Explicit argument wins; otherwise fall back to the config, defaulting
        # to plain "absolute" position handling.
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One learned embedding per possible (query - key) distance.
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention over `hidden_states` (or `encoder_hidden_states`).

        Returns a tuple of `(context_layer[, attention_probs][, past_key_value])`:
        probabilities only when `output_attentions` is True, the key/value cache
        only when configured as a decoder.
        """
        mixed_query_layer = self.q_proj(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))
            # Prepend the cached keys/values along the sequence dimension.
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.k_proj(hidden_states))
            value_layer = self.transpose_for_scores(self.v_proj(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # With a cache, only the newest query position matters.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in ErnieMModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
class ErnieMAttention(nn.Module):
    """`ErnieMSelfAttention` followed by an output projection back to hidden_size."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self_attn = ErnieMSelfAttention(config, position_embedding_type=position_embedding_type)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output projections.

        Updates `num_attention_heads` / `all_head_size` accordingly and records
        the pruned heads so already-pruned ones are skipped on repeated calls.
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self_attn.num_attention_heads, self.self_attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self_attn.q_proj = prune_linear_layer(self.self_attn.q_proj, index)
        self.self_attn.k_proj = prune_linear_layer(self.self_attn.k_proj, index)
        self.self_attn.v_proj = prune_linear_layer(self.self_attn.v_proj, index)
        self.out_proj = prune_linear_layer(self.out_proj, index, dim=1)

        # Update hyper params and store pruned heads
        self.self_attn.num_attention_heads = self.self_attn.num_attention_heads - len(heads)
        self.self_attn.all_head_size = self.self_attn.attention_head_size * self.self_attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run self-attention and project the context back to `hidden_size`.

        Returns `(attention_output, ...)` where the trailing elements (attention
        probabilities, key/value cache) are whatever `self_attn` emitted.
        """
        self_outputs = self.self_attn(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.out_proj(self_outputs[0])
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
+
+
class ErnieMEncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention then a feed-forward block,
    each with a residual connection followed by (post-)LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # to mimic paddlenlp implementation
        dropout = 0.1 if config.hidden_dropout_prob is None else config.hidden_dropout_prob
        act_dropout = config.hidden_dropout_prob if config.act_dropout is None else config.act_dropout

        self.self_attn = ErnieMAttention(config)
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(act_dropout)
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = True,
    ):
        """Apply the layer to `hidden_states`.

        Returns `(hidden_states, attention_weights)` when `output_attentions`
        is True (the default), otherwise just `hidden_states`.
        """
        residual = hidden_states
        if output_attentions:
            hidden_states, attention_opt_weights = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )

        else:
            # ErnieMAttention always returns a tuple; keep only the hidden
            # states. (The original code assigned the whole tuple here, which
            # crashed in `self.dropout1` below whenever output_attentions=False.)
            hidden_states = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
            )[0]
        hidden_states = residual + self.dropout1(hidden_states)
        hidden_states = self.norm1(hidden_states)

        # Feed-forward block with its own residual + LayerNorm.
        residual = hidden_states
        hidden_states = self.linear1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.linear2(hidden_states)
        hidden_states = residual + self.dropout2(hidden_states)
        hidden_states = self.norm2(hidden_states)

        if output_attentions:
            return hidden_states, attention_opt_weights
        else:
            return hidden_states
+
+
class ErnieMEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `ErnieMEncoderLayer`s."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ErnieMEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(
        self,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the layers sequentially, optionally collecting per-layer
        hidden states and attention weights.

        Returns a `BaseModelOutputWithPastAndCrossAttentions` (or a tuple of
        the non-None fields when `return_dict` is False). When collected,
        `hidden_states` also includes the input embeddings as its first entry.
        """
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None

        output = input_embeds
        if output_hidden_states:
            hidden_states = hidden_states + (output,)
        for i, layer in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # NOTE(review): `output_attentions` is not forwarded here, so each
            # layer runs with its default (True) and always returns
            # (hidden, weights); the weights are simply discarded below when
            # `output_attentions` is False — confirm this is intentional.
            output, opt_attn_weights = layer(
                hidden_states=output,
                attention_mask=attention_mask,
                head_mask=layer_head_mask,
                past_key_value=past_key_value,
            )

            if output_hidden_states:
                hidden_states = hidden_states + (output,)
            if output_attentions:
                attentions = attentions + (opt_attn_weights,)

        last_hidden_state = output
        if not return_dict:
            return tuple(v for v in [last_hidden_state, hidden_states, attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=attentions
        )
+
+
class ErnieMPooler(nn.Module):
    """Pools a sequence into a single vector: the first token's hidden state is
    passed through a tanh-activated dense projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" by using only the hidden state of the first token.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
+
+
class ErnieMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ErnieMConfig
    base_model_prefix = "ernie_m"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        std = self.config.initializer_range
        if isinstance(module, nn.LayerNorm):
            # LayerNorm starts out as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization, cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding vector at exactly zero.
                module.weight.data[module.padding_idx].zero_()
+
+
# Generic model description injected via `add_start_docstrings` on the classes below.
ERNIE_M_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ErnieMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Per-forward input documentation; the `{0}` placeholder is filled with the
# expected input shape by `add_start_docstrings_to_model_forward`.
ERNIE_M_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`ErnieMTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare ErnieM Model transformer outputting raw hidden-states without any specific head on top.",
    ERNIE_M_START_DOCSTRING,
)
class ErnieMModel(ErnieMPreTrainedModel):
    """Bare ErnieM encoder: embeddings + encoder stack + optional pooler."""

    def __init__(self, config, add_pooling_layer=True):
        super(ErnieMModel, self).__init__(config)
        self.initializer_range = config.initializer_range
        self.embeddings = ErnieMEmbeddings(config)
        self.encoder = ErnieMEncoder(config)
        # Pooler is optional so heads that do not need it avoid the extra weights.
        self.pooler = ErnieMPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding module (used for tying/resizing)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding module."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layers[layer].self_attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Embed the inputs, build the additive attention mask and run the encoder."""
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.")

        # init the default bool value
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # NOTE(review): other classes in this file use `self.config.use_return_dict`
        # here; `config.return_dict` skips the torchscript special-casing — confirm
        # this difference is intentional.
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # Length of the cached prefix, used to offset positions and pad the mask.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel
        if attention_mask is None:
            # Derive an additive mask from input_ids: PAD positions get the dtype
            # minimum (effectively -inf), all other positions 0.
            attention_mask = (input_ids == self.config.pad_token_id).to(torch.float32)
            attention_mask *= torch.finfo(attention_mask.dtype).min
            if past_key_values is not None:
                batch_size = past_key_values[0][0].shape[0]
                # NOTE(review): past_mask is created on the default device —
                # presumably it should follow attention_mask.device; confirm on GPU.
                past_mask = torch.zeros([batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype)
                attention_mask = torch.concat([past_mask, attention_mask], dim=-1)
        # For 2D attention_mask from tokenizer
        elif attention_mask.ndim == 2:
            # Convert the tokenizer's 1/0 keep-mask into an additive mask
            # (0 for kept tokens, dtype-min for masked ones).
            attention_mask = attention_mask.to(torch.float32)
            attention_mask = 1.0 - attention_mask
            attention_mask *= torch.finfo(attention_mask.dtype).min

        # Assumes a (batch, seq_len) mask at this point; broadcast to
        # (batch, 1, 1, seq_len) so it can be added to per-head scores.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            sequence_output = encoder_outputs[0]
            pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
            return (sequence_output, pooler_output) + encoder_outputs[1:]

        sequence_output = encoder_outputs["last_hidden_state"]
        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None
        hidden_states = None if not output_hidden_states else encoder_outputs["hidden_states"]
        attentions = None if not output_attentions else encoder_outputs["attentions"]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=hidden_states,
            attentions=attentions,
        )
+
+
+@add_start_docstrings(
+    """ErnieM Model transformer with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks.""",
+    ERNIE_M_START_DOCSTRING,
+)
+class ErnieMForSequenceClassification(ErnieMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.ernie_m = ErnieMModel(config)
+        # Fall back to the generic hidden dropout when no classifier-specific dropout is configured.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = True,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # NOTE(review): `use_cache` is accepted for signature parity but is not forwarded to the backbone below.
+        outputs = self.ernie_m(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            output_hidden_states=output_hidden_states,
+            output_attentions=output_attentions,
+            return_dict=return_dict,
+        )
+
+        # Index 1 is the pooler output for both the tuple and the ModelOutput return formats.
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            # Infer the loss type once from label dtype / num_labels and cache it on the config,
+            # so subsequent calls reuse the same problem type.
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            # Tuple convention: (loss?, logits, *extra backbone outputs).
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """ErnieM Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""",
+    ERNIE_M_START_DOCSTRING,
+)
+class ErnieMForMultipleChoice(ErnieMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.ernie_m = ErnieMModel(config)
+        # Fall back to the generic hidden dropout when no classifier-specific dropout is configured.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        # One scalar score per choice; scores are reshaped to (batch, num_choices) below.
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MultipleChoiceModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+            `input_ids` above)
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        # Flatten (batch, num_choices, seq) -> (batch * num_choices, seq) so each choice is
+        # encoded as an independent sequence by the backbone.
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.ernie_m(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # Index 1 is the pooler output for both the tuple and the ModelOutput return formats.
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        # Un-flatten the per-choice scores back to (batch, num_choices) for the softmax/CE loss.
+        reshaped_logits = logits.view(-1, num_choices)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return MultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """ErnieM Model with a token classification head on top (a linear layer on top of
+    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""",
+    ERNIE_M_START_DOCSTRING,
+)
+class ErnieMForTokenClassification(ErnieMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # Token-level head: the pooler is not needed, so it is disabled on the backbone.
+        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)
+        # Fall back to the generic hidden dropout when no classifier-specific dropout is configured.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = True,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie_m(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # Index 0 is the last hidden state for both the tuple and the ModelOutput return formats.
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Flatten tokens across the batch so CE is computed per token.
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """ErnieM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""",
+    ERNIE_M_START_DOCSTRING,
+)
+class ErnieMForQuestionAnswering(ErnieMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # Span head needs per-token states only, so the pooler is disabled on the backbone.
+        self.ernie_m = ErnieMModel(config, add_pooling_layer=False)
+        # Projects each token's hidden state to (start_logit, end_logit).
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie_m(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        # Split the 2-channel projection into separate (batch, seq, 1) tensors, then drop the last dim.
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            # (clamped positions equal `ignored_index`, which CE then skips via ignore_index).
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+ """ErnieMForInformationExtraction is a Ernie-M Model with two linear layer on top of the hidden-states output to
+ compute `start_prob` and `end_prob`, designed for Universal Information Extraction.""",
+ ERNIE_M_START_DOCSTRING,
+)
+class ErnieMForInformationExtraction(ErnieMPreTrainedModel):
+ def __init__(self, config):
+ super(ErnieMForInformationExtraction, self).__init__(config)
+ self.ernie_m = ErnieMModel(config)
+ self.linear_start = nn.Linear(config.hidden_size, 1)
+ self.linear_end = nn.Linear(config.hidden_size, 1)
+ self.sigmoid = nn.Sigmoid()
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ start_positions: Optional[torch.Tensor] = None,
+ end_positions: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for position (index) for computing the start_positions loss. Position outside of the sequence are
+ not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) for computing the end_positions loss. Position outside of the sequence are not
+ taken into account for computing the loss.
+ """
+
+ result = self.ernie_m(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ if return_dict:
+ sequence_output = result.last_hidden_state
+ elif not return_dict:
+ sequence_output = result[0]
+
+ start_logits = self.linear_start(sequence_output)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = self.linear_end(sequence_output)
+ end_logits = end_logits.squeeze(-1)
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = BCEWithLogitsLoss()
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ return tuple(
+ i
+ for i in [total_loss, start_logits, end_logits, result.hidden_states, result.attentions]
+ if i is not None
+ )
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=result.hidden_states,
+ attentions=result.attentions,
+ )
+
+
+# Names exported by this module via `from <module> import *`.
+__all__ = [
+    "ErnieMForMultipleChoice",
+    "ErnieMForQuestionAnswering",
+    "ErnieMForSequenceClassification",
+    "ErnieMForTokenClassification",
+    "ErnieMModel",
+    "ErnieMPreTrainedModel",
+    "ErnieMForInformationExtraction",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..44bc197a4f7c463c71eb64dd941dacce51148675
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py
@@ -0,0 +1,410 @@
+# coding=utf-8
+# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for Ernie-M."""
+
+import io
+import os
+import unicodedata
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ....tokenization_utils import PreTrainedTokenizer
+from ....utils import logging
+from ....utils.import_utils import requires
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"}
+
+RESOURCE_FILES_NAMES = {
+ "sentencepiece_model_file": "sentencepiece.bpe.model",
+ "vocab_file": "vocab.txt",
+}
+
+
+# Adapted from paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer
+@requires(backends=("sentencepiece",))
+class ErnieMTokenizer(PreTrainedTokenizer):
+    r"""
+    Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words.
+
+    Args:
+        sentencepiece_model_file (`str`):
+            The file path of sentencepiece model.
+        vocab_file (`str`, *optional*):
+            The file path of the vocabulary.
+        do_lower_case (`str`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            A special token representing the `unknown (out-of-vocabulary)` token. An unknown token is set to be
+            `unk_token` inorder to be converted to an ID.
+        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+            A special token separating two different sentences in the same input.
+        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+            A special token used to make arrays of tokens the same size for batching purposes.
+        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+            A special token used for sequence classification. It is the last token of the sequence when built with
+            special tokens.
+        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+            A special token representing a masked token. This is the token used in the masked language modeling task
+            which the model tries to predict the original unmasked ones.
+    """
+
+    # Ernie-M model doesn't have token_type embedding.
+    model_input_names: List[str] = ["input_ids"]
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    resource_files_names = RESOURCE_FILES_NAMES
+
+    def __init__(
+        self,
+        sentencepiece_model_ckpt,
+        vocab_file=None,
+        do_lower_case=False,
+        encoding="utf8",
+        unk_token="[UNK]",
+        sep_token="[SEP]",
+        pad_token="[PAD]",
+        cls_token="[CLS]",
+        mask_token="[MASK]",
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> None:
+        # Mask token behave like a normal word, i.e. include the space before it and
+        # is included in the raw text, there should be a match in a non-normalized sentence.
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.do_lower_case = do_lower_case
+        self.sentencepiece_model_ckpt = sentencepiece_model_ckpt
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(sentencepiece_model_ckpt)
+
+        # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
+        # Prefer an explicit vocab file; otherwise derive the vocab from the sentencepiece model.
+        if vocab_file is not None:
+            self.vocab = self.load_vocab(filepath=vocab_file)
+        else:
+            self.vocab = {self.sp_model.id_to_piece(id): id for id in range(self.sp_model.get_piece_size())}
+        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
+
+        super().__init__(
+            do_lower_case=do_lower_case,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            vocab_file=vocab_file,
+            encoding=encoding,
+            sp_model_kwargs=self.sp_model_kwargs,
+            **kwargs,
+        )
+
+    # Maps each token produced by `tokenize` back to a (start, end) character span of the raw text.
+    def get_offset_mapping(self, text):
+        if text is None:
+            return None
+
+        split_tokens = self.tokenize(text)
+        normalized_text, char_mapping = "", []
+
+        # Normalize the text the same way tokenization does, remembering for every
+        # normalized character which original character index it came from.
+        for i, ch in enumerate(text):
+            # NOTE(review): `SP_CHAR_MAPPING` is referenced here and in `clean_text` but is not
+            # defined anywhere in this class — verify it is provided by a parent/mixin or restore it.
+            if ch in self.SP_CHAR_MAPPING:
+                ch = self.SP_CHAR_MAPPING.get(ch)
+            else:
+                ch = unicodedata.normalize("NFKC", ch)
+            if self.is_whitespace(ch):
+                continue
+            normalized_text += ch
+            char_mapping.extend([i] * len(ch))
+
+        text, token_mapping, offset = normalized_text, [], 0
+
+        if self.do_lower_case:
+            text = text.lower()
+
+        # Locate each token in the normalized text and translate its span back to raw-text indices.
+        for token in split_tokens:
+            if token[:1] == "▁":
+                token = token[1:]
+            start = text[offset:].index(token) + offset
+            end = start + len(token)
+
+            token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1))
+            offset = end
+        return token_mapping
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
+    # The SentencePieceProcessor is a native object and is not picklable: drop it when
+    # serializing and rebuild it from the checkpoint path in `__setstate__`.
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(self.sentencepiece_model_ckpt)
+
+    def clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        # NOTE(review): depends on `self.SP_CHAR_MAPPING`, which is not defined in this class — verify.
+        return "".join((self.SP_CHAR_MAPPING.get(c, c) for c in text))
+
+    def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1):
+        """Tokenize a string."""
+
+        # `sp_model_kwargs` overrides the sampling parameters passed as arguments.
+        if self.sp_model_kwargs.get("enable_sampling") is True:
+            enable_sampling = True
+        if self.sp_model_kwargs.get("alpha") is not None:
+            alpha = self.sp_model_kwargs.get("alpha")
+        if self.sp_model_kwargs.get("nbest_size") is not None:
+            nbest_size = self.sp_model_kwargs.get("nbest_size")
+
+        if not enable_sampling:
+            pieces = self.sp_model.EncodeAsPieces(text)
+        else:
+            pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha)
+        # Post-process the sentencepiece output: split every piece further at CJK characters,
+        # punctuation, and digit/non-digit boundaries, so those always become standalone tokens.
+        new_pieces = []
+        for pi, piece in enumerate(pieces):
+            if piece == SPIECE_UNDERLINE:
+                # Keep a lone "▁" only when the next piece does not already carry it.
+                if not pieces[pi + 1].startswith(SPIECE_UNDERLINE) and pi != 0:
+                    new_pieces.append(SPIECE_UNDERLINE)
+                    continue
+                else:
+                    continue
+            lst_i = 0
+            for i, chunk in enumerate(piece):
+                if chunk == SPIECE_UNDERLINE:
+                    continue
+                if self.is_ch_char(chunk) or self.is_punct(chunk):
+                    # Flush the pending run, then emit the CJK char / punctuation on its own.
+                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
+                        new_pieces.append(piece[lst_i:i])
+                    new_pieces.append(chunk)
+                    lst_i = i + 1
+                elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit():
+                    # Boundary: non-digit -> digit.
+                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
+                        new_pieces.append(piece[lst_i:i])
+                    lst_i = i
+                elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit():
+                    # Boundary: digit -> non-digit.
+                    if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE:
+                        new_pieces.append(piece[lst_i:i])
+                    lst_i = i
+            if len(piece) > lst_i:
+                new_pieces.append(piece[lst_i:])
+        return new_pieces
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def convert_ids_to_string(self, ids):
+        """
+        Converts a sequence of tokens (strings for sub-words) in a single string.
+        """
+        tokens = self.convert_ids_to_tokens(ids)
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
+    def _convert_token_to_id(self, token):
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.reverse_vocab.get(index, self.unk_token)
+
+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+        r"""
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. An ErnieM sequence has the following format:
+
+        - single sequence: `[CLS] X [SEP]`
+        - pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+        Returns:
+            `List[int]`: List of input_id with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        _cls = [self.cls_token_id]
+        _sep = [self.sep_token_id]
+        return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep
+
+    def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None):
+        r"""
+        Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. An Ernie-M
+        offset_mapping has the following format:
+
+        - single sequence: `(0,0) X (0,0)`
+        - pair of sequences: `(0,0) A (0,0) (0,0) B (0,0)`
+
+        Args:
+            offset_mapping_ids_0 (`List[tuple]`):
+                List of char offsets to which the special tokens will be added.
+            offset_mapping_ids_1 (`List[tuple]`, *optional*):
+                Optional second list of wordpiece offsets for offset mapping pairs.
+        Returns:
+            `List[tuple]`: List of wordpiece offsets with the appropriate offsets of special tokens.
+        """
+        if offset_mapping_1 is None:
+            return [(0, 0)] + offset_mapping_0 + [(0, 0)]
+
+        return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)]
+
+    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
+        r"""
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `encode` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids of the first sequence.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`str`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+        Returns:
+            `List[int]`:
+                The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+            return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
+
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create the token type IDs corresponding to the sequences passed. [What are token type
+        IDs?](../glossary#token-type-ids) Should be overridden in a subclass if the model has a special way of
+        building: those.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                The first tokenized sequence.
+            token_ids_1 (`List[int]`, *optional*):
+                The second tokenized sequence.
+        Returns:
+            `List[int]`: The token type ids.
+        """
+        # called when `add_special_tokens` is True, so align with `build_inputs_with_special_tokens` method
+        if token_ids_1 is None:
+            # [CLS] X [SEP]
+            return (len(token_ids_0) + 2) * [0]
+
+        # [CLS] A [SEP] [SEP] B [SEP]
+        return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3)
+
+    def is_ch_char(self, char):
+        """
+        is_ch_char
+        """
+        # CJK Unified Ideographs block.
+        if "\u4e00" <= char <= "\u9fff":
+            return True
+        return False
+
+    def is_alpha(self, char):
+        """
+        is_alpha
+        """
+        if ("a" <= char <= "z") or ("A" <= char <= "Z"):
+            return True
+        return False
+
+    def is_punct(self, char):
+        """
+        is_punct
+        """
+        # ASCII punctuation plus common full-width CJK punctuation.
+        if char in ",;:.?!~,;:。?!《》【】":
+            return True
+        return False
+
+    def is_whitespace(self, char):
+        """
+        is whitespace
+        """
+        if char == " " or char == "\t" or char == "\n" or char == "\r":
+            return True
+        # Any other single char in the Unicode "space separator" category also counts.
+        if len(char) == 1:
+            cat = unicodedata.category(char)
+            if cat == "Zs":
+                return True
+        return False
+
+    # Reads one token per line; the line number becomes the token id.
+    def load_vocab(self, filepath):
+        token_to_idx = {}
+        with io.open(filepath, "r", encoding="utf-8") as f:
+            for index, line in enumerate(f):
+                token = line.rstrip("\n")
+                token_to_idx[token] = int(index)
+
+        return token_to_idx
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            # Write tokens ordered by id; warn if ids are not consecutive (gaps would shift ids on reload).
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+
+        # NOTE(review): this join assumes `save_directory` is a directory, even though the
+        # branch above also accepts a plain file path — verify against callers.
+        tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model")
+        with open(tokenizer_model_file, "wb") as fi:
+            content_spiece_model = self.sp_model.serialized_model_proto()
+            fi.write(content_spiece_model)
+
+        return (vocab_file,)
+
+
+__all__ = ["ErnieMTokenizer"]  # public API exported via `import *`
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c23b58f35012f1ddc00e2275c484810287de6f8
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_gptsan_japanese import *
+ from .modeling_gptsan_japanese import *
+ from .tokenization_gptsan_japanese import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd565810095955c1d7a6e199db93d2ef1a499173
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py
@@ -0,0 +1,157 @@
+# coding=utf-8
+# Copyright 2023, HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""GPTSAN-japanese model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GPTSanJapaneseConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate
+ a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese
+ [Tanrei/GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Arguments:
+ vocab_size (`int`, *optional*, defaults to 36000):
+ Vocabulary size of the GPTSANJapanese model. Defines the number of different tokens that can be represented
+ by the `inputs_ids` passed when calling [`GPTSanJapaneseModel`].
+ max_position_embeddings (`int`, *optional*, defaults to 1280):
+ The maximum sequence length that this model might ever be used with. Defaults set this to 1280.
+ d_model (`int`, *optional*, defaults to 1024):
+ Size of the encoder layers and the pooler layer.
+ d_ff (`int`, *optional*, defaults to 8192):
+ Size of the intermediate feed forward layer in each `SwitchTransformersBlock`.
+ d_ext (`int`, *optional*, defaults to 4096):
+ Size of the intermediate feed forward layer in each Extra-layers.
+ d_spout (`int`, *optional*, defaults to 128):
+ Size of the `spout` vector.
+ num_switch_layers (`int`, *optional*, defaults to 10):
+ Number of layers in the Switch Transformer layer.
+ num_ext_layers (`int`, *optional*, defaults to 0):
+ Number of layers in the Extra-layers.
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_experts (`int`, *optional*, defaults to 16):
+ Number of experts for each SwitchTransformer layer.
+ expert_capacity (`int`, *optional*, defaults to 128):
+ Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular
+ Transformer.
+ dropout_rate (`float`, *optional*, defaults to 0.0):
+ The ratio for all dropout layers.
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+ The epsilon used by the layer normalization layers.
+ router_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the router.
+ router_jitter_noise (`float`, *optional*, defaults to 0.0):
+ Amount of noise to add to the router. Set it to 0.0 during prediction or set small value (usually 1e-2)
+ during training.
+ router_dtype (`str`, *optional*, defaults to `"float32"`):
+ The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
+ *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
+ router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
+ Whether to ignore padding tokens when routing.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers.
+ initializer_factor (`float`, *optional*, defaults to 0.002):
+ A factor for initializing all weight matrices.
+ output_router_logits (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the router logits of all experts.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models)
+ """
+
+ model_type = "gptsan-japanese"
+ keys_to_ignore_at_inference = [
+ "past_key_values",
+ ]
+ attribute_map = {
+ "hidden_size": "d_model",
+ "num_attention_heads": "num_heads",
+ "num_hidden_layers": "num_layers",
+ }
+
+ def __init__(
+ self,
+ vocab_size=36000,
+ max_position_embeddings=1280,
+ d_model=1024,
+ d_ff=8192,
+ d_ext=4096,
+ d_spout=128,
+ num_switch_layers=10,
+ num_ext_layers=0,
+ num_heads=16,
+ num_experts=16,
+ expert_capacity=128,
+ dropout_rate=0.0,
+ layer_norm_epsilon=1e-5,
+ router_bias=False,
+ router_jitter_noise=0.0,
+ router_dtype="float32",
+ router_ignore_padding_tokens=False,
+ output_hidden_states=False,
+ output_attentions=False,
+ initializer_factor=0.002,
+ output_router_logits=False,
+ use_cache=True,
+ separator_token_id=35998,
+ pad_token_id=35995,
+ eos_token_id=35999,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.d_ff = d_ff
+ self.d_ext = d_ext
+ self.d_spout = d_spout
+ self.num_switch_layers = num_switch_layers
+ self.num_ext_layers = num_ext_layers
+ self.num_layers = num_switch_layers + num_ext_layers
+ self.num_heads = num_heads
+ self.num_experts = num_experts
+ self.expert_capacity = expert_capacity
+ self.dropout_rate = dropout_rate
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.router_bias = router_bias
+ self.router_jitter_noise = router_jitter_noise
+ self.router_dtype = router_dtype
+ self.router_ignore_padding_tokens = router_ignore_padding_tokens
+ self.output_hidden_states = output_hidden_states
+ self.output_attentions = output_attentions
+ self.initializer_factor = initializer_factor
+ self.output_router_logits = output_router_logits
+ self.use_cache = use_cache
+
+ super().__init__(
+ separator_token_id=separator_token_id,
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+
+
+__all__ = ["GPTSanJapaneseConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a84d000d44390fe6ae821fb1cdfba968d40a2b93
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model."""
+
+import argparse
+import json
+import os
+from collections import OrderedDict
+
+import numpy as np
+import tensorflow as tf
+import torch
+
+
+def convert_tf_gptsan_to_pt(args):
+ parameter_file = os.path.join(args.tf_model_dir, "parameters.json")
+ params = json.loads(open(parameter_file).read())
+ if not params:
+ raise ValueError(
+ f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file."
+ )
+ if not args.output.endswith(".pt"):
+ args.output = args.output + ".pt"
+ new_state = OrderedDict()
+ with tf.device("/CPU:0"):
+ reader = tf.train.load_checkpoint(args.tf_model_dir)
+ shapes = reader.get_variable_to_shape_map()
+ for key_name in shapes.keys():
+ vnp = reader.get_tensor(key_name).astype(np.float16)
+ if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"):
+ continue
+ if key_name.startswith("pasts/"):
+ if key_name.startswith("pasts/mlp"):
+ player = int(key_name[9])
+ elif key_name.startswith("pasts/out"):
+ player = 8
+ name = "model.sqout.%d.weight" % (player * 2) # enter to nn.Sequential with Tanh, so 2 at a time
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/moe"):
+ player = int(key_name[9:].split("/")[0])
+ if key_name.endswith("/switch_gating/kernel"):
+ name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/softmlp/kernel"):
+ name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"):
+ nlayer = key_name[-9:-7]
+ for i in range(16):
+ name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer)
+ state = (
+ vnp[i].transpose([1, 0]).copy()
+ ) # In Mesh-Tensorflow, it is one array, so it is divided
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/mlp"):
+ player = int(key_name[9:].split("/")[0])
+ if key_name.endswith("/p1/kernel"):
+ name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/p1/bias"):
+ name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/p2/kernel"):
+ name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/p2/bias"):
+ name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/ln"):
+ player = int(key_name[8:].split("/")[0])
+ if key_name.endswith("/b"):
+ name = "model.blocks.%d.feed_forward.norm.bias" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/g"):
+ name = "model.blocks.%d.feed_forward.norm.weight" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/att"):
+ player = int(key_name[9:].split("/")[0])
+ if key_name.endswith("/qkv/kernel"):
+ state = vnp.copy() # Compute same dimension as Mesh-tensorflow using einsum
+ state_q = state[:, 0, :, :]
+ state_k = state[:, 1, :, :]
+ state_v = state[:, 2, :, :]
+ state_q = (
+ state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]])
+ .transpose([1, 0])
+ .copy()
+ ) # Mesh-Tensorflow is a diagonal matrix
+ state_k = (
+ state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]])
+ .transpose([1, 0])
+ .copy()
+ ) # Mesh-Tensorflow is a diagonal matrix
+ state_v = (
+ state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]])
+ .transpose([1, 0])
+ .copy()
+ ) # Mesh-Tensorflow is a diagonal matrix
+ name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player
+ new_state[name] = torch.tensor(state_q)
+ name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player
+ new_state[name] = torch.tensor(state_k)
+ name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player
+ new_state[name] = torch.tensor(state_v)
+ elif key_name.endswith("/o/kernel"):
+ name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player
+ state = (
+ vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy()
+ ) # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/an"):
+ player = int(key_name[8:].split("/")[0])
+ if key_name.endswith("/b"):
+ name = "model.blocks.%d.self_attn.norm.bias" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif key_name.endswith("/g"):
+ name = "model.blocks.%d.self_attn.norm.weight" % player
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ elif (
+ key_name.startswith("model/wte")
+ or key_name.startswith("model/wpe")
+ or key_name.startswith("model/ete")
+ ):
+ nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[
+ key_name[-3:]
+ ]
+ name = "model.%s.weight" % nlayer
+ state = vnp.copy() # same in embedded
+ new_state[name] = torch.tensor(state)
+ if key_name.startswith("model/wte"):
+ name = "lm_head.weight"
+ state = vnp.copy() # same in embedded
+ new_state[name] = torch.tensor(state)
+ elif key_name.startswith("model/wob"):
+ name = "final_logits_bias"
+ state = vnp.copy() # same in embedded
+ state = state.reshape((1, -1))
+ new_state[name] = torch.tensor(state)
+ elif key_name == "model/dense/kernel":
+ name = "model.last_project.weight"
+ state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix
+ new_state[name] = torch.tensor(state)
+ elif key_name == "model/dense_1/bias":
+ name = "model.last_project.bias"
+ state = vnp.copy() # same because it is one dimensional
+ new_state[name] = torch.tensor(state)
+ torch.save(new_state, args.output)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model")
+ parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model")
+ args = parser.parse_args()
+ convert_tf_gptsan_to_pt(args)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..a35ea4a31199887a32e6912bfb2d3d6036635b38
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
@@ -0,0 +1,1337 @@
+# coding=utf-8
+# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch GPTSANJapanese model."""
+
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ....activations import ACT2FN
+from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions
+from ....modeling_utils import PreTrainedModel
+from ....utils import (
+ DUMMY_INPUTS,
+ DUMMY_MASK,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_torch_fx_proxy,
+ logging,
+)
+from .configuration_gptsan_japanese import GPTSanJapaneseConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "GPTSanJapaneseConfig"
+_CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese"
+
+####################################################
+# This dict contains ids and associated url
+# for the pretrained weights provided with the models
+####################################################
+
+
+def router_z_loss_func(router_logits: torch.Tensor) -> float:
+ r"""
+ Compute the router z-loss implemented in PyTorch.
+
+ The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
+ It encourages router logits to remain small in an effort to improve stability.
+
+ Args:
+ router_logits (`float`):
+ Input logits of shape [batch_size, sequence_length, num_experts]
+
+ Returns:
+ Scalar router z-loss.
+ """
+ num_groups, tokens_per_group, _ = router_logits.shape
+ log_z = torch.logsumexp(router_logits, dim=-1)
+ z_loss = log_z**2
+ return torch.sum(z_loss) / (num_groups * tokens_per_group)
+
+
+def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
+ r"""
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+ experts is too unbalanced.
+
+ Args:
+ router_probs (`torch.Tensor`):
+ Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
+ expert_indices (`torch.Tensor`):
+ Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.
+
+ Returns:
+ The auxiliary loss.
+ """
+ num_experts = router_probs.shape[-1]
+
+ # cast the expert indices to int64, otherwise one-hot encoding will fail
+ if expert_indices.dtype != torch.int64:
+ expert_indices = expert_indices.to(torch.int64)
+
+ if len(expert_indices.shape) == 2:
+ expert_indices = expert_indices.unsqueeze(2)
+
+ expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
+
+ # For a given token, determine if it was routed to a given expert.
+ expert_mask = torch.max(expert_mask, axis=-2).values
+
+ # cast to float32 otherwise mean will fail
+ expert_mask = expert_mask.to(torch.float32)
+ tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2)
+
+ router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2)
+ return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
+
+
+class GPTSanJapaneseDenseActDense(nn.Module):
+ """
+ FFN Layer for Switch Transformer and Extra layers
+
+ GPTSAN can mix Switch Transformer layers and normal Transformer layers This class is used as Expert in Switch
+ Transformer layers and as FFN in regular Transformer layers. RELU is used in the Switch Transformer layer, and
+ Swish is used in the normal Transformer layer, so there is a choice of which is used in the argument.
+
+ """
+
+ def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False):
+ super().__init__()
+ d_inter = config.d_ext if ext_layer else config.d_ff
+ self.wi = nn.Linear(config.d_model, d_inter, bias=ext_layer)
+ self.wo = nn.Linear(d_inter, config.d_model, bias=ext_layer)
+ self.dropout = nn.Identity() if ext_layer else nn.Dropout(config.dropout_rate)
+ self.act = ACT2FN["swish" if ext_layer else "relu"]
+
+ def forward(self, hidden_states):
+ r"""
+ Args:
+ hidden_states (`torch.Tensor`) :
+ [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
+ Returns:
+ torch.Tensor[num_groups, tokens_per_group, hidden_dim]
+
+ """
+ hidden_states = self.wi(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.wo(hidden_states)
+ return hidden_states
+
+
+class GPTSanJapaneseTop1Router(nn.Module):
+ """
+ Router using tokens choose top-1 experts assignment.
+
+ This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
+ (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
+ routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
+ token is processed by an expert**, or that each expert receives at least one token.
+
+ """
+
+ def __init__(self, config: GPTSanJapaneseConfig):
+ super().__init__()
+ self.num_experts = config.num_experts
+ self.expert_capacity = config.expert_capacity
+ self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
+ self.jitter_noise = config.router_jitter_noise
+ self.ignore_padding_tokens = config.router_ignore_padding_tokens
+ self.dtype = getattr(torch, config.router_dtype)
+
+ def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Computes router probabilities from input hidden states.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.
+ Returns:
+ router_probabilities (`torch.Tensor`):
+ Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
+ token and expert. Used for routing tokens to experts.
+ router_logits (`torch.Tensor`):
+ Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
+ This is used later for computing router z-loss.
+ """
+ # float32 is used to ensure stability. See the discussion of "selective precision" in
+ # https://arxiv.org/abs/2101.03961.
+ # We also store the previous dtype to cast back the output to the previous dtype
+ self.input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(self.dtype)
+
+ if self.training and self.jitter_noise > 0:
+ # Multiply the token inputs by the uniform distribution - adding some noise
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
+
+ # Shape: [num_groups, tokens_per_group, num_experts]
+ self._cast_classifier()
+ router_logits = self.classifier(hidden_states)
+
+ # Apply Softmax and cast back to the original `dtype`
+ router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
+ return router_probabilities, router_logits
+
+ def _cast_classifier(self):
+ r"""
+ `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an
+ instance of the `Linear8bitLt` class by checking special attributes.
+ """
+ if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
+ self.classifier = self.classifier.to(self.dtype)
+
+ def forward(self, hidden_states: torch.Tensor) -> Tuple:
+ r"""
+ Generic forward function for every Router class. Each Router expects to have the same input hidden states
+ (`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the
+ number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert.
+
+ Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and
+ `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned
+ to an expert. Then each Router class will have to define its own `_compute_routing_instructions`.
+
+ Args:
+ hidden_states (`torch.Tensor`) :
+ [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
+ Returns:
+ Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs
+ and the router logits. The router probabilities and logits are required to compute the loss.
+ """
+ router_probs, router_logits = self._compute_router_probabilities(hidden_states)
+
+ expert_index = torch.argmax(router_probs, dim=-1)
+ expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)
+
+ # Mask tokens outside expert capacity. Sum over each sequence
+ token_priority = torch.cumsum(expert_index, dim=-2)
+ # mask if the token routed to the expert will overflow
+ expert_capacity_mask = token_priority <= self.expert_capacity
+ expert_index = expert_index * expert_capacity_mask
+
+ router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
+ return expert_index, router_probs, router_logits
+
+
+class GPTSanJapaneseSparseMLP(nn.Module):
+ r"""
+ Implementation of the Switch Transformers Sparse MLP module.
+ """
+
+ def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense):
+ super().__init__()
+ # Step 1: Get the correct router according to its class
+ self.router = GPTSanJapaneseTop1Router(config)
+
+ # Step 2: Get the experts
+ self.experts = nn.ModuleDict()
+ for idx in range(config.num_experts):
+ self.experts[f"expert_{idx}"] = expert_class(config)
+
+ def forward(self, hidden_states):
+ r"""
+ Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following:
+
+ 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
+ and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
+ hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor).
+
+ 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each
+ expert the corresponding hidden states.
+
+ """
+ # Step 1: Get the router_mask from the router as well as the probabilities
+ router_mask, router_probs, router_logits = self.router(hidden_states)
+ expert_index = torch.argmax(router_mask, dim=-1)
+
+ # The routers introduced might not always map all the tokens, to a router, which means that some hidden states
+ # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.
+
+ next_states = hidden_states.clone()
+ for idx, expert in enumerate(self.experts.values()):
+ token_indices = router_mask[:, :, idx].bool()
+ next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)
+
+ hidden_states = router_probs * next_states
+ return hidden_states, (router_logits, expert_index)
+
+
+class GPTSanJapaneseLayerSparseFF(nn.Module):
+ r"""
+ Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module.
+
+ Parameters:
+ config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ def __init__(self, config: GPTSanJapaneseConfig):
+ super().__init__()
+ self.mlp = GPTSanJapaneseSparseMLP(config)
+ self.soft_bypass_mlp = nn.Linear(config.d_model, config.d_model, bias=False)
+ self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+ def forward(self, hidden_states, output_router_logits):
+ r"""
+ Args:
+ hidden_states (`torch.Tensor`) :
+ [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
+ output_router_logits (`bool`) :
+ output experts router output.
+ Returns:
+ torch.Tensor[num_groups, tokens_per_group, hidden_dim]
+
+ """
+ forwarded_states, router_tuple = self.mlp(hidden_states)
+ forwarded_states += torch.tanh(self.soft_bypass_mlp(hidden_states))
+ output = hidden_states + self.norm(forwarded_states)
+
+ if output_router_logits and router_tuple is not None:
+ return output, router_tuple
+ else:
+ return output
+
+
+class GPTSanJapaneseLayerDenseFF(nn.Module):
+ r"""
+ Extra Transformers Feed Forward layer module.
+
+ Parameters:
+ config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ def __init__(self, config: GPTSanJapaneseConfig):
+ super().__init__()
+ # Check if it is a sparse layer, if not then it is a dense layer
+ self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True)
+ self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+ def forward(self, hidden_states):
+ r"""
+ Args:
+ hidden_states (`torch.Tensor`) :
+ [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
+ Returns:
+ torch.Tensor[num_groups, tokens_per_group, hidden_dim]
+
+ """
+ forwarded_states = self.mlp(hidden_states)
+ output = hidden_states + self.norm(forwarded_states)
+ return output
+
+
class GPTSanJapaneseAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[GPTSanJapaneseConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        # Per-head dimensionality; `embed_dim` must divide evenly by `num_heads` (validated below).
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim) scaling, applied to the query projection before the dot product.
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states (`torch.Tensor`): query inputs of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states (`torch.Tensor`, *optional*): when given, the layer acts as cross-attention
                and keys/values are projected from this tensor instead of `hidden_states`.
            past_key_value (`Tuple[torch.Tensor]`, *optional*): cached key/value states from previous steps.
            attention_mask (`torch.Tensor`, *optional*): additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask (`torch.Tensor`, *optional*): multiplicative per-head mask of shape `(num_heads,)`.
            output_attentions (`bool`): whether to also return the attention probabilities.

        Returns:
            Tuple of `(attn_output, attn_weights or None, past_key_value or None)`.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Flatten batch and head dims so attention can be computed with a single bmm.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # Additive mask is applied before softmax (masked positions hold large negative values).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
+
+
class GPTSanJapaneseLayerSelfAttention(nn.Module):
    """
    Self-attention followed by a residual LayerNorm connection.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseAttention(
            embed_dim=config.d_model,
            num_heads=config.num_heads,
            is_decoder=True,
            bias=has_relative_attention_bias,
        )
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Run masked self-attention and add the normalized result back onto the input.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Input activations for the attention block.
            past_key_value (`Tuple[torch.Tensor]`, *optional*):
                Cached key/value states from previous decoding steps; only the first two
                entries are consumed by self-attention.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask with 1 for visible tokens and 0 for masked ones; converted here into an
                additive large-negative mask for the attention module.
            head_mask (`torch.FloatTensor`, *optional*):
                Per-head multiplicative mask (1 keeps a head, 0 nullifies it).
            use_cache (`bool`, *optional*):
                When `True`, the present key/value cache is appended to the returned tuple.
            output_attentions (`bool`, *optional*):
                When `True`, the attention probabilities are appended to the returned tuple.
        Returns:
            Tuple whose first element is the updated hidden states, optionally followed by
            the present key/value cache and the attention probabilities.
        """
        # Self-attention uses only the first two slots of the cache tuple.
        cached_self_attn = None if past_key_value is None else past_key_value[:2]
        # Convert the 0/1 visibility mask into additive form (masked positions -> dtype minimum).
        additive_mask = (1 - attention_mask) * torch.finfo(hidden_states.dtype).min
        attn_results = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=cached_self_attn,
            attention_mask=additive_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )

        # Residual connection around the normalized attention output.
        residual = hidden_states + self.norm(attn_results[0])

        outputs = (residual,)
        if use_cache:
            outputs += (attn_results[2],)  # present key/value cache
        if output_attentions:
            outputs += (attn_results[1],)  # attention probabilities
        return outputs
+
+
class GPTSanJapaneseBlock(nn.Module):
    """
    One transformer block: self-attention followed by a feed-forward unit
    (sparse mixture-of-experts by default, dense when `ext_layer=True`).
    """

    def __init__(self, config, ext_layer=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseLayerSelfAttention(config)
        if ext_layer:
            self.feed_forward = GPTSanJapaneseLayerDenseFF(config)
        else:
            self.feed_forward = GPTSanJapaneseLayerSparseFF(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_tuple: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        GPTSAN transformer block.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Input activations for the block.
            past_key_value (`Tuple[torch.Tensor]`, *optional*):
                Cached key/value states from previous decoding steps, forwarded to self-attention.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask with 1 for visible tokens and 0 for masked ones.
            head_mask (`torch.FloatTensor`, *optional*):
                Per-head multiplicative mask (1 keeps a head, 0 nullifies it).
            use_cache (`bool`, *optional*):
                When `True`, the present key/value cache is included in the output tuple.
            output_attentions (`bool`, *optional*):
                When `True`, the attention probabilities are included in the output tuple.
            output_router_tuple (`bool`, *optional*):
                When `True` and the feed-forward unit is sparse, the expert router tuple is
                appended to the output tuple.
        Returns:
            Tuple whose first element is the updated hidden states, followed by the optional
            self-attention extras and, last, the optional router tuple.
        """
        attn_results = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_hidden = attn_results[0]

        is_sparse = isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF)
        router_tuple = None
        if not is_sparse:
            # Dense feed-forward: no router information exists.
            ff_hidden = self.feed_forward(attn_hidden)
        elif output_router_tuple:
            ff_hidden, router_tuple = self.feed_forward(attn_hidden, output_router_tuple)
        else:
            ff_hidden = self.feed_forward(attn_hidden, output_router_tuple)

        # Keep the extras from self-attention (cache / attention probabilities) in order.
        outputs = (ff_hidden,) + attn_results[1:]
        if is_sparse and output_router_tuple:
            outputs += (router_tuple,)
        return outputs
+
+
class GPTSanJapanesePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTSanJapaneseConfig
    base_model_prefix = "gptsan_japanese"
    supports_gradient_checkpointing = False
    _no_split_modules = ["GPTSanJapaneseBlock"]
    _skip_keys_device_placement = "past_key_values"

    @property
    def dummy_inputs(self):
        """Minimal `input_ids`/`attention_mask` pair used by the library for tracing and sanity checks."""
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "input_ids": input_ids,
            "attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights, following the Mesh TensorFlow scheme used by the original GPTSAN model."""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(factor * 1.0)
            module.bias.data.zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseModel):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0)
            module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None:
                module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseForConditionalGeneration):
            # NOTE: this branch previously also listed `GPTSanJapaneseModel`, which was unreachable
            # here because the `elif` above already handles that type; the dead entry was removed.
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, GPTSanJapaneseDenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, GPTSanJapaneseAttention):
            # Multi-headed attention
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_model
            n_heads = self.config.num_heads
            module.k_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.v_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
        elif isinstance(module, GPTSanJapaneseSparseMLP):
            # Mesh TensorFlow router/expert initialization (the unused attention-style locals were removed).
            d_model = self.config.d_model
            module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1)
            for idx in range(self.config.num_experts):
                module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
                module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))

    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position to the right, prepending `config.decoder_start_token_id` and
        replacing any `-100` label-padding values with `config.pad_token_id`.

        Raises:
            ValueError: if `decoder_start_token_id` or `pad_token_id` is not set on the config.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
+
+
# Shared class-level docstring; "Swich" typo fixed to "Switch".
GPTSAN_JAPANESE_START_DOCSTRING = r"""

    The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Switch transformer
    based Japanese language model

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Shared description of the forward() arguments, injected into the model docstrings via
# `add_start_docstrings_to_model_forward`.
GPTSAN_JAPANESE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. GPTSAN-japanese is a model that generates sentence
            continuations or predicts tokens at mask positions. Special tokens required for inputs to the model are
            automatically appended.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            An input that masks the Prefix part in the Prefix-LM input. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **prefix** input,
            - 0 for tokens that are **not-prefix** input.
        spout (`torch.Tensor` of shape `(batch_size, config.d_spout)`):
            This vector is transformed through an 8-layer FFN and can be used instead of `past_key_values`.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
"""
+
+
@add_start_docstrings(
    "The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__(config)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.config = copy.deepcopy(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.last_project = nn.Linear(config.d_model, config.d_model, bias=True)
        self.act = ACT2FN["swish"]

        # Sparse (switch) layers come first, followed by optional dense extension layers.
        self.blocks = torch.nn.ModuleList([])
        for _ in range(config.num_switch_layers):
            self.blocks.append(GPTSanJapaneseBlock(config))
        for _ in range(config.num_ext_layers):
            self.blocks.append(GPTSanJapaneseBlock(config, ext_layer=True))

        if config.num_ext_layers > 0:
            self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

        if config.d_spout:
            # 8 tanh-activated projections followed by a final projection to the past_key_values shape.
            spouts = []
            for _ in range(8):
                spouts.append(nn.Linear(config.d_spout, config.d_spout, bias=False))
                spouts.append(nn.Tanh())
            spouts.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False))
            self.spout = nn.Sequential(*spouts)

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        num_precontext: Optional[torch.LongTensor] = None,
    ) -> Union[MoEModelOutputWithPastAndCrossAttentions, Tuple[torch.FloatTensor]]:
        r"""
        num_precontext (`torch.LongTensor` of shape `(batch_size,1)`):
            length of `hybrid` input tokens in the input. Tokens up to this length refer to both front and back like
            BERT, tokens after that refer only to front like GPT. see also:
            https://github.com/tanreinama/GPTSAN/blob/main/report/model.md

        Returns:
            `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
            MoEModelOutputWithPastAndCrossAttentions instead of tuple
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        device = self.position_embeddings.weight.device
        if input_ids is None:
            input_ids = torch.zeros([1, 1]).int().to(device)  # dummy for input_ids was None
        if inputs_embeds is not None:
            raise NotImplementedError(
                "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
            )
        num_pasts_contexts = 0
        num_batch = input_ids.shape[0]
        pasts_or_spout_value = None
        if past_key_values is not None:
            num_pasts_contexts = past_key_values[0][0].shape[2]
        elif self.config.d_spout and spout is not None:
            # `spout` is a special input vector specific to GPTSAN
            # This controls the output by projecting embedded information such as the class of sentences during learning.
            # It should be passed instead of the first past_key_value.
            # See the original GPTSAN repository for details
            num_pasts_contexts += 1

        # If there is an attention_mask, increase first one for spout
        if self.config.d_spout and spout is not None and attention_mask is not None:
            attention_mask_with_spout = torch.ones(num_batch, attention_mask.shape[1] + 1, device=device)
            attention_mask_with_spout[:, 1:] -= 1 - attention_mask  # 1st token should be spout
            attention_mask = attention_mask_with_spout  # update attention_mask

        if num_precontext is not None:
            # `num_precontext` is the number of tokens that refer to each other in prefix-lm
            # created per batch, so dimension of num_precontext should be [batch, 1]
            if not (
                len(num_precontext.shape) == 2 and num_precontext.shape[1] == 1
            ):  # num_precontext Should be [batch,1]
                raise ValueError("num_precontext should be [batch, 1] size.")
            num_precontext = torch.reshape(num_precontext, [-1])
        else:
            num_precontext = torch.zeros([num_batch]).int().to(device)

        num_input_contexts = input_ids.shape[1]
        num_output_contexts = num_input_contexts + num_pasts_contexts

        hidden_states = self.embed_tokens(input_ids)

        if past_key_values is not None:
            pasts_or_spout_value = past_key_values
        elif self.config.d_spout and spout is not None:
            # Make vector from `spout` of GPTSAN to the same shape as past_key_values
            pasts_or_spout_value = self.spout(spout)  # projecting `spout` vector
            pasts_or_spout_value = torch.reshape(
                pasts_or_spout_value,
                [
                    num_batch,
                    self.config.num_layers,
                    2,
                    self.config.num_heads,
                    num_pasts_contexts,
                    self.config.d_model // self.config.num_heads,
                ],
            )
            pasts_or_spout_value = torch.split(pasts_or_spout_value, [1] * self.config.num_layers, dim=1)
            # make same shape as past_key_values
            pasts_or_spout_value = tuple(
                tuple([b.squeeze(1) for b in torch.split(a.squeeze(1), [1, 1], dim=1)]) for a in pasts_or_spout_value
            )
        else:
            pasts_or_spout_value = [None] * self.config.num_layers

        # Token position considering spout and pasts
        token_position = torch.arange(num_input_contexts).to(device) + num_pasts_contexts

        if attention_mask is None:
            attention_mask = torch.ones(num_batch, num_input_contexts, device=device)

        # positions for get position_embeddings
        gather_position = (
            (
                torch.zeros((num_batch, self.config.d_model, num_input_contexts)).to(device)
                + token_position.unsqueeze(0)
            )
            .transpose(1, 2)
            .long()
        )
        # When padding with padding_side="left", zeros line up on the left side of attention_mask, so position_embeddings is shifted accordingly
        gather_position -= (1 - attention_mask).argmin(dim=-1).unsqueeze(1).unsqueeze(2)
        gather_position = torch.clip(gather_position, num_pasts_contexts, self.config.max_position_embeddings - 1)

        # attention_mask is applied per batch
        for i in range(num_batch):
            hidden_states[i] += torch.gather(self.position_embeddings.weight, dim=0, index=gather_position[i])

        # Create a mask to be used when making the prefix Input length of Prefix-LM variable
        causal_mask = (
            torch.tril(torch.ones((num_output_contexts, num_output_contexts), dtype=torch.uint8))
            .view(1, 1, num_output_contexts, num_output_contexts)
            .to(device)
        )
        prefix_lm_mask = causal_mask[:, :, -num_input_contexts:, :]
        if token_type_ids is not None:
            token_type_ids = token_type_ids.unsqueeze(1).unsqueeze(2)
            prefix_lm_mask = ((prefix_lm_mask + token_type_ids) > 0).float()
        # Merge prefix_lm_mask and attention_mask
        extended_attention_mask = prefix_lm_mask * attention_mask.unsqueeze(1).unsqueeze(2)

        # Prepare head mask if needed
        if head_mask is not None:
            head_mask = self.get_head_mask(
                head_mask, self.config.num_switch_layers + self.config.num_ext_layers
            )  # n_layer x batch x n_heads x N x N

        # outputs
        present_key_value_states = () if self.config.use_cache or use_cache else None
        all_hidden_states = () if self.config.output_hidden_states or output_hidden_states else None
        all_attentions = () if self.config.output_attentions or output_attentions else None
        all_router_probs = () if self.config.output_router_logits or output_router_logits else None

        for layer, past in enumerate(pasts_or_spout_value):
            if layer == self.config.num_switch_layers:
                if self.config.num_ext_layers > 0:
                    # extra_position_embeddings are extra position embeddings that are only created when extending the model with code from the original GPTSAN repository. Not used in the default model.
                    # However, it is created when you create an additional layer and partially train only that location.
                    # Therefore, convert_gptsan_tf_checkpoint_to_pytorch.py is used when converting and loading models created in the original GPTSAN repository.
                    for i in range(num_batch):
                        hidden_states[i] += torch.gather(
                            self.extra_position_embeddings.weight, dim=0, index=gather_position[i]
                        )

            output_router_tuple = (
                self.config.output_router_logits or output_router_logits
            ) and layer < self.config.num_switch_layers
            block_output = self.blocks[layer](
                hidden_states=hidden_states,
                past_key_value=past,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                use_cache=self.config.use_cache or use_cache,
                output_attentions=self.config.output_attentions or output_attentions,
                output_router_tuple=output_router_tuple,
            )

            outpos = 0
            hidden_states = block_output[outpos]
            if self.config.output_hidden_states or output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.config.use_cache or use_cache:
                outpos += 1
                present = block_output[outpos]
                present_key_value_states += (present,)
            if self.config.output_attentions or output_attentions:
                outpos += 1
                attention_probs = block_output[outpos]
                all_attentions += (attention_probs,)
            if output_router_tuple:
                outpos += 1
                router_tuple = block_output[outpos]
                # BUGFIX: `all_router_probs` is a tuple (see its initialization above), so it must be
                # extended with `+=` — the previous `.append(...)` raised AttributeError at runtime.
                all_router_probs += (router_tuple[0],)

        hidden_states = self.last_project(hidden_states)
        hidden_states = self.act(hidden_states)

        if self.config.output_hidden_states or output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_router_probs,
                ]
                if v is not None
            )

        return MoEModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            router_probs=all_router_probs,
        )
+
+
@add_start_docstrings(
    "The bare GPTSAN-japanese Model with a language modeling head.",
    GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
    # `lm_head.weight` is tied to the input embeddings in `__init__`, so it is
    # declared here to keep save/load from duplicating the shared tensor.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: GPTSanJapaneseConfig):
        """Build the LM wrapper: backbone, final-logits bias buffer, and LM head."""
        super().__init__(config)
        self.model = GPTSanJapaneseModel(config)
        # Additive per-vocabulary-entry bias on the LM logits; stored as a
        # buffer (saved with the model, never trained) and resized together
        # with the embeddings via `_resize_final_logits_bias`.
        self.register_buffer("final_logits_bias", torch.zeros([1, config.vocab_size]))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the output projection to the input embeddings; skipped for
        # TorchScript export, presumably because traced modules cannot share
        # Parameters — TODO confirm.
        if not self.config.torchscript:
            self.lm_head.weight = self.model.embed_tokens.weight

    @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], MoECausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:
            `MoECausalLMOutputWithPast` or `tuple` if `return_dict` returns MoECausalLMOutputWithPast instead of tuple

        Example:

        Text Generation with regular LM Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("織田信長は、", return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> gen_token = model.generate(input_ids, max_new_tokens=50)
        >>> tokenizer.decode(gen_token[0])
        "織田信長は、政治・軍事の中枢まで掌握した政治家であり、日本史上類を見ない驚異的な軍事侵攻を続け..."
        ```

        Text Generation with Prefix-LM Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("", prefix_text="織田信長は、", return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> token_type_ids = x_token.token_type_ids.to(device)
        >>> gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
        >>> tokenizer.decode(gen_token[0])
        "織田信長は、政治・外交で数々の戦果を上げるが、1568年からは、いわゆる本能寺の変で細川晴元に暗殺される..."
        ```

        Simultaneously Text Generation And Masked Language Model
        ```python
        >>> from transformers import AutoModel, AutoTokenizer, trainer_utils

        >>> device = "cuda"
        >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
        >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> masked_sentence = "武田信玄は、<|inputmask|>時代ファンならぜひ押さえ<|inputmask|>きたい名将の一人。"
        >>> x_token = tokenizer("", prefix_text=masked_sentence, return_tensors="pt")
        >>> trainer_utils.set_seed(30)
        >>> input_ids = x_token.input_ids.to(device)
        >>> token_type_ids = x_token.token_type_ids.to(device)
        >>> out_lm_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
        >>> out_mlm_token = model(input_ids, token_type_ids=token_type_ids).logits.argmax(axis=-1)
        >>> tokenizer.decode(out_mlm_token[0])
        "武田信玄は、戦国時代ファンならぜひ押さえておきたい名将の一人。"

        >>> tokenizer.decode(out_lm_token[0][input_ids.shape[1] :])
        "武田氏の三代に渡った武田家のひとり\n甲斐市に住む、日本史上最大の戦国大名。..."
        ```"""
        SEG_TOKEN = self.config.separator_token_id
        use_cache = use_cache or self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Always request a dataclass from the inner model so the attribute
        # accesses below (`outputs.past_key_values`, ...) are valid even when
        # the caller asked for a plain tuple from this wrapper.
        model_return_dict = True
        num_precontext = None
        if input_ids is not None:
            num_batch = input_ids.shape[0]
            # Per-row length of the Prefix-LM prefix: the position of the
            # separator token in each sequence.
            num_precontext = torch.zeros([num_batch]).int().to(input_ids.device)
            where_separators = torch.where(input_ids == SEG_TOKEN)
            # NOTE(review): positions are accumulated with `+=`, so a row that
            # contains more than one separator token would sum the positions —
            # this assumes at most one SEG per row; confirm with the tokenizer.
            num_precontext[where_separators[0]] += where_separators[1]
            num_precontext = num_precontext.unsqueeze(1)

        # Arguments are passed positionally; the order must match
        # `GPTSanJapaneseModel.forward`.
        outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            spout,
            past_key_values,
            head_mask,
            use_cache,
            inputs_embeds,
            decoder_inputs_embeds,
            output_attentions,
            output_hidden_states,
            model_return_dict,
            output_router_logits,
            num_precontext,
        )

        lm_logits = self.lm_head(outputs[0])
        # Only apply the bias while the head still matches the buffer's
        # vocabulary size (they can diverge after `set_output_embeddings`).
        if lm_logits.shape[-1] == self.final_logits_bias.shape[-1]:
            lm_logits = lm_logits + self.final_logits_bias

        loss = None
        z_loss = None
        router_probs = None
        aux_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

            if output_router_logits:
                # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
                router_logits, expert_indexes = self._unpack_router_logits(outputs.router_probs)
                z_loss = router_z_loss_func(router_logits)
                router_probs = nn.Softmax(dim=-1)(router_logits)
                aux_loss = load_balancing_loss_func(router_probs, expert_indexes)

            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            # The tuple form drops every None entry, so positional indices in
            # the result depend on which optional outputs were produced.
            return tuple(
                v
                for v in [
                    loss,
                    lm_logits,
                    outputs.past_key_values,
                    outputs.hidden_states,
                    outputs.router_probs,
                    z_loss,
                    aux_loss,
                ]
                if v is not None
            )

        return MoECausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_probs,
            z_loss=z_loss,
            aux_loss=aux_loss,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
        token_type_ids: Optional[torch.FloatTensor] = None,
        spout: Optional[Union[List, torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        **kwargs,
    ):
        """Shape the inputs for one `generate` step.

        With a cache present only the last token (and last token_type_id) is
        fed; otherwise the full sequences are passed and the cache stays None.
        """
        # `spout` may arrive as a plain Python list via generate kwargs.
        if isinstance(spout, list):
            spout = torch.tensor(spout).float()
            if input_ids is not None:
                spout = spout.to(input_ids.device)
        if past_key_values is not None:
            return {
                "input_ids": input_ids[:, -1:] if input_ids is not None else None,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids[:, -1:] if token_type_ids is not None else None,
                "spout": spout,
                "past_key_values": past_key_values,
            }
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "spout": spout,
            "past_key_values": None,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Produce decoder inputs from labels by right-shifting them one step."""
        return self._shift_right(labels)

    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
    ) -> nn.Embedding:
        """Resize the token embeddings and keep `final_logits_bias` in sync."""
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        """Truncate or zero-pad the logits-bias buffer to the new vocab size."""
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # New vocabulary entries start with zero bias.
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.model.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def _unpack_router_logits(self, router_outputs):
        """Concatenate per-layer (router_logits, expert_indexes) pairs along dim 1."""
        total_router_logits = []
        total_expert_indexes = []
        for router_output in router_outputs:
            # Only entries whose logits are at least 2-D are real router
            # outputs; other (placeholder) entries are skipped.
            if len(router_output[0].shape) > 1:
                router_logits, expert_indexes = router_output
                total_router_logits.append(router_logits)
                total_expert_indexes.append(expert_indexes)
        return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1)
+
+
# Public symbols of this module, re-exported by the package-level lazy module.
__all__ = ["GPTSanJapaneseForConditionalGeneration", "GPTSanJapaneseModel", "GPTSanJapanesePreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py
new file mode 100644
index 0000000000000000000000000000000000000000..c93ea87278d767fb12b48554ffd6b7f611f7e034
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py
@@ -0,0 +1,518 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for GPTSANJapanese."""
+
+import collections
+import json
+import os
+import re
+import sys
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+
+from ....tokenization_utils import PreTrainedTokenizer
+from ....tokenization_utils_base import (
+ BatchEncoding,
+ PreTokenizedInput,
+ PreTokenizedInputPair,
+ TextInput,
+ TextInputPair,
+ TruncationStrategy,
+)
+from ....utils import PaddingStrategy, logging
+
+
logger = logging.get_logger(__name__)

# File names under which the tokenizer vocabulary and emoji table are saved.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"}
+
+
def load_vocab_and_emoji(vocab_file, emoji_file):
    """Load a vocabulary file and an emoji file and return the lookup tables.

    Each vocab line may hold several comma-separated spellings that share one
    id. Returns `(vocab, raw_vocab, ids_to_tokens, emoji)` where `vocab` maps
    every spelling to its id, `raw_vocab` maps the raw line to its id, and
    `ids_to_tokens` maps the id back to the list of spellings.
    """
    with open(emoji_file, "r", encoding="utf-8") as handle:
        emoji = json.loads(handle.read())

    vocab = collections.OrderedDict()
    raw_vocab = collections.OrderedDict()
    ids_to_tokens = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as handle:
        raw_lines = handle.readlines()

    # A line that is exactly "," (or has no comma at all) is a single spelling;
    # otherwise the line is split into its comma-separated spellings.
    entries = []
    for line in raw_lines:
        if line == ",\n" or "," not in line:
            entries.append([line.rstrip("\n")])
        else:
            entries.append(line.rstrip("\n").split(","))

    for row_id, spellings in enumerate(entries):
        ids_to_tokens[row_id] = spellings
        raw_vocab[",".join(spellings)] = row_id
        for spelling in spellings:
            vocab[spelling] = row_id

    return vocab, raw_vocab, ids_to_tokens, emoji
+
+
class GPTSanJapaneseTokenizer(PreTrainedTokenizer):
    """
    This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
    - Decoding byte0~byte255 tokens correctly
    - Added bagofword token handling
    - Return token_type_ids for Prefix-LM model
    The bagofword token represents a repetition of the previous token and is converted to 3 consecutive tokens when
    decoding In addition, the original Japanese special Sub-Word-Encoding has been released in this repository
    (https://github.com/tanreinama/Japanese-BPEEncoder_V2). The token_type_ids is a mask indicating the prefix input
    position of the Prefix-LM model. To specify a prefix position, specify a prefix input for prefix_text, or specify a
    sentence of the prefix part and the part after it as a text pair of batch input.

    Example:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> # You can confirm both 慶応 and 慶應 are encoded to 17750
    >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]
    [35993, 35998, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]

    >>> # Both 慶応 and 慶應 are decoded to 慶応
    >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"])
    '吾輩は猫である🐯。実は慶応(慶応)大学出身'
    ```

    Example for Prefix-LM:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["input_ids"]
    [35993, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 35998, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281]

    >>> # Mask for Prefix-LM inputs
    >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["token_type_ids"]
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ```

    Example for batch encode:

    ```python
    >>> from transformers import GPTSanJapaneseTokenizer

    >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["input_ids"]
    [[35993, 35998, 8640, 25948, 35993, 35998, 30647, 35675, 35999, 35999], [35993, 35998, 10382, 9868, 35993, 35998, 30646, 9459, 30646, 35675]]

    >>> # Mask for Prefix-LM inputs
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["token_type_ids"]
    [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

    >>> # Mask for padding
    >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["attention_mask"]
    [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    ```

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        emoji_file (`str`):
            File containing the emoji.
        unk_token (`str`, *optional*, defaults to `"<|nottoken|>"`):
            The token used for unknown character
        pad_token (`str`, *optional*, defaults to `"<|separator|>"`):
            The token used for padding
        bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
            The end of sequence token.
        sep_token (`str`, *optional*, defaults to `"<|segmenter|>"`):
            A special token to separate token to prefix part and general input part.
        do_clean_text (`bool`, *optional*, defaults to `False`):
            Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    def __init__(
        self,
        vocab_file,
        emoji_file,
        unk_token="<|nottoken|>",
        pad_token="<|separator|>",
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>",
        sep_token="<|segmenter|>",
        do_clean_text=False,
        **kwargs,
    ):
        # Both resource files must exist before any state is built.
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        if not os.path.isfile(emoji_file):
            raise ValueError(
                f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google"
                " pretrained model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.do_clean_text = do_clean_text
        self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file)
        # The actual sub-word segmentation is delegated to this helper; it must
        # be built before super().__init__ because the base class may tokenize
        # special tokens during initialization.
        self.subword_tokenizer = SubWordJapaneseTokenizer(
            vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji
        )

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            do_clean_text=do_clean_text,
            **kwargs,
        )

    @property
    def vocab_size(self):
        # self.vocab contains support for character fluctuation unique to Japanese, and has a large number of vocab
        return len(self.raw_vocab)

    def get_vocab(self):
        """Return the raw vocabulary merged with any tokens added at runtime."""
        return dict(self.raw_vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        """Split `text` into sub-word tokens (optionally cleaning URLs etc. first)."""
        return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.subword_tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # NOTE(review): several literals in the branch chain below are special
        # single-character placeholders from the Japanese-BPEEncoder_V2 vocab
        # (space/newline/tab/block/symbol markers). They render as empty or
        # broken here and may have been garbled by encoding — verify against
        # the upstream transformers source before editing.
        words = []
        byte_tokens = []
        for word in tokens:
            if word[:6] == "<|byte" and word[-2:] == "|>":
                # Raw byte token: accumulate until a non-byte token flushes it.
                byte_tokens.append(int(word[6:-2]))
            else:
                if len(byte_tokens) > 0:
                    # Flush pending bytes as one UTF-8 decoded chunk.
                    words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
                    byte_tokens = []
                if word[:7] == "<|emoji" and word[-2:] == "|>":
                    words.append(self.emoji["emoji_inv"][word])
                elif word == "":
                    words.append(" ")
                elif word == "
":
                    words.append("\n")
                elif word == "":
                    words.append("\t")
                elif word == "":
                    words.append("▀")
                elif word == "":
                    words.append("ǀ")
                elif word == "":
                    words.append("‖")
                elif word == "<|bagoftoken|>":
                    # Bag-of-token expands to three repeats of the previous word.
                    if len(words) > 0:
                        words.append(words[-1])
                        words.append(words[-1])
                        words.append(words[-1])
                elif word.startswith("<|") and word.endswith("|>"):
                    # Any other control token is dropped from the output text.
                    words.append("")
                else:
                    words.append(word)
        if len(byte_tokens) > 0:
            # Flush any trailing byte run.
            words.append(bytearray(byte_tokens).decode("utf-8", errors="replace"))
        text = "".join(words)
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write vocab.txt and emoji.json under `save_directory` and return their paths."""
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
            emoji_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"]
            )
        else:
            # `save_directory` is treated as a path prefix when it is not a
            # directory (same pattern as other HF tokenizers).
            vocab_file = (
                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"]
            )
            emoji_file = (
                (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"]
            )
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token_index, token in self.ids_to_tokens.items():
                # Detect holes in the id space; the file format is positional,
                # so non-consecutive ids would corrupt the mapping on reload.
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(",".join(token) + "\n")
                index += 1
        with open(emoji_file, "w", encoding="utf-8") as writer:
            json.dump(self.emoji, writer)
        return vocab_file, emoji_file

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # docstyle-ignore
        """
        The tokenizer returns token_type_ids as separators between the Prefix part and the rest.
        token_type_ids is 1 for the Prefix part and 0 for the rest of the token.

        Example:
        ```python
        >>> from transformers import GPTSanJapaneseTokenizer

        >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
        >>> x_token = tokenizer("アイウエ")
        >>> # input_ids:      | SOT | SEG | ア | イ | ウ | エ |
        >>> # token_type_ids: | 1   | 0   | 0 | 0 | 0 | 0 |

        >>> x_token = tokenizer("", prefix_text="アイウエ")
        >>> # input_ids:      | SOT | ア | イ | ウ | エ | SEG |
        >>> # token_type_ids: | 1   | 1 | 1 | 1 | 1 | 0  |

        >>> x_token = tokenizer("ウエ", prefix_text="アイ")
        >>> # input_ids:      | SOT | ア | イ | SEG | ウ | エ |
        >>> # token_type_ids: | 1   | 1 | 1 | 0   | 0 | 0 |
        ```"""
        # Prefix length = index of the first separator token in sequence 0.
        prefix_len = 0
        if self.sep_token in self.vocab:
            segid = self.vocab[self.sep_token]
            if segid in token_ids_0:
                prefix_len = token_ids_0.index(segid)
        if token_ids_1 is None:
            total_len = len(token_ids_0)
        else:
            total_len = len(token_ids_0 + token_ids_1)
        return prefix_len * [1] + (total_len - prefix_len) * [0]

    def prepare_for_tokenization(self, text, prefix_text=None, add_sep_token=None, **kwargs):
        """Prepend SOT (and SEG between prefix and body) before tokenization."""
        # GPTSAN inserts extra SEP tokens in Prefix-LM in addition to SOT for text generation.
        # SOT at the beginning of the text, and SEP at the separator between the Prefix part and the rest.
        if add_sep_token is None:
            # Only auto-insert SEG when the text does not already carry one.
            add_sep_token = self.sep_token not in text  # If insert un-prefix position explicitly
        prepared = self.bos_token if self.bos_token in self.vocab else ""
        prepared += prefix_text if prefix_text is not None else ""
        if add_sep_token:
            prepared += self.sep_token if self.sep_token in self.vocab else ""
        prepared += text
        return (prepared, kwargs)

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair]
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        # This tokenizer converts input text pairs into Prefix input and subsequent input
        # NOTE(review): the right-hand operand below is always False
        # (`tuple(...)` can never be an instance of `list`), so only tuple
        # pairs take this branch; list pairs fall through to the parent
        # implementation — verify intended behavior against upstream.
        if isinstance(batch_text_or_text_pairs[0], tuple) or isinstance(tuple(batch_text_or_text_pairs[0]), list):
            # As a single text with an explicit un-prefix position
            batch_prefix_texts = []
            for pref, txt in batch_text_or_text_pairs:
                batch_prefix_texts.append(pref + self.sep_token + txt)
            batch_text_or_text_pairs = batch_prefix_texts

        return super()._batch_encode_plus(
            batch_text_or_text_pairs,
            add_special_tokens,
            padding_strategy,
            truncation_strategy,
            max_length,
            stride,
            is_split_into_words,
            pad_to_multiple_of,
            return_tensors,
            return_token_type_ids,
            return_attention_mask,
            return_overflowing_tokens,
            return_special_tokens_mask,
            return_offsets_mapping,
            return_length,
            verbose,
            **kwargs,
        )
+
+
class SubWordJapaneseTokenizer:
    """
    This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications
    - Decoding byte0~byte255 tokens correctly
    - Added bagofword token handling

    https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT License according to the
    original repository.

    MIT License

    Copyright (c) 2020 tanreinama

    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
    documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
    rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
    permit persons to whom the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all copies or substantial portions of
    the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
    THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    SOFTWARE.
    """

    def __init__(self, vocab, ids_to_tokens, emoji):
        self.vocab = vocab  # same as swe
        self.ids_to_tokens = ids_to_tokens  # same as bpe
        self.emoji = emoji
        # Longest spelling in the vocab — bounds the look-ahead in tokenize().
        self.maxlen = np.max([len(w) for w in self.vocab.keys()])
        # Pre-compiled cleaners used by clean_text(): URLs, e-mail addresses,
        # phone-number-like digit groups, dates (年/月/日 forms), era dates
        # (明治..令和), and prices (円/ドル/ユーロ forms).
        self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)")
        self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*")
        self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}")
        self.content_repatter4 = re.compile(
            r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
        )
        self.content_repatter5 = re.compile(
            r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*"
        )
        # The original version of this regex displays catastrophic backtracking behaviour. We avoid this using
        # possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly
        # different regex that should generally have the same behaviour in most non-pathological cases.
        if sys.version_info >= (3, 11):
            self.content_repatter6 = re.compile(
                r"(?:\d,\d{3}|[\d億])*+"
                r"(?:\d,\d{3}|[\d万])*+"
                r"(?:\d,\d{3}|[\d千])*+"
                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
                r"(?:\(税込\)|\(税抜\)|\+tax)*"
            )
        else:
            self.content_repatter6 = re.compile(
                r"(?:\d,\d{3}|[\d億万千])*"
                r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+"
                r"(?:\(税込\)|\(税抜\)|\+tax)*"
            )
        # Box-drawing and block-element characters are stripped from inputs.
        keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿"
        blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟"
        self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, ""))

    def __len__(self):
        # Number of distinct token ids (not spellings).
        return len(self.ids_to_tokens)

    def clean_text(self, content):
        """Replace URLs, e-mails, phone numbers, dates and prices with placeholders.

        NOTE(review): the replacement literals below are special placeholder
        characters from the Japanese-BPEEncoder_V2 vocabulary; they render as
        empty strings here and appear garbled by encoding. In particular the
        final `while` loop would never terminate with a genuinely empty
        literal — verify against the upstream transformers source.
        """
        content = self.content_repatter1.sub("", content)
        content = self.content_repatter2.sub("", content)
        content = self.content_repatter3.sub("", content)
        content = self.content_repatter4.sub("", content)
        content = self.content_repatter5.sub("", content)
        content = self.content_repatter6.sub("", content)
        content = content.translate(self.content_trans1)
        while "" in content:
            content = content.replace("", "")
        return content

    def tokenize(self, text, clean=False):
        """Greedy longest-match sub-word tokenization with byte fallback.

        NOTE(review): the replacement literals in the normalization below are
        special placeholder characters (space/newline/tab markers) that render
        as empty or line-broken strings here — likely garbled by encoding;
        verify against the upstream transformers source.
        """
        text = text.replace(" ", "")
        text = text.replace(" ", "")
        text = text.replace("\r\n", "
")
        text = text.replace("\n", "
")
        text = text.replace("\r", "
")
        text = text.replace("\t", "")
        text = text.replace("—", "ー")
        text = text.replace("−", "ー")
        # Map known emoji strings to their placeholder tokens.
        for k, v in self.emoji["emoji"].items():
            if k in text:
                text = text.replace(k, v)
        if clean:
            text = self.clean_text(text)

        def check_simbol(x):
            # True for a single character whose UTF-8 encoding is 2 bytes and
            # falls in the listed symbol ranges.
            e = x.encode()
            if len(x) == 1 and len(e) == 2:
                c = (int(e[0]) << 8) + int(e[1])
                if (
                    (c >= 0xC2A1 and c <= 0xC2BF)
                    or (c >= 0xC780 and c <= 0xC783)
                    or (c >= 0xCAB9 and c <= 0xCBBF)
                    or (c >= 0xCC80 and c <= 0xCDA2)
                ):
                    return True
            return False

        def checku2e(x):
            # True for a single character whose 3-byte UTF-8 encoding lies in
            # the U+2000–U+2BFF area (encoded 0xE28080–0xE2B07F).
            e = x.encode()
            if len(x) == 1 and len(e) == 3:
                c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2])
                if c >= 0xE28080 and c <= 0xE2B07F:
                    return True
            return False

        pos = 0
        result = []
        while pos < len(text):
            # Special tokens start with "<" and may be up to maxlen+1 long;
            # ordinary matches look ahead at most 3 characters.
            end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3
            candidates = []  # (token_id, token, pos)
            for e in range(end, pos, -1):
                wd = text[pos:e]
                if wd in self.vocab:
                    if wd[0] == "<" and len(wd) > 2:
                        # A special token match wins outright.
                        candidates = [(self.vocab[wd], wd, e)]
                        break
                    else:
                        candidates.append((self.vocab[wd], wd, e))
            if len(candidates) > 0:
                # the smallest token_id is adopted
                _, wd, e = sorted(candidates, key=lambda x: x[0])[0]
                result.append(wd)
                pos = e
            else:
                # No vocab match: emit a symbol placeholder or raw byte tokens.
                end = pos + 1
                wd = text[pos:end]
                if check_simbol(wd):
                    result.append("")
                elif checku2e(wd):
                    result.append("")
                else:
                    for i in wd.encode("utf-8"):
                        result.append("<|byte%d|>" % i)
                pos = end
        return result

    def convert_id_to_token(self, index):
        # First spelling of the entry is the canonical token text.
        return self.ids_to_tokens[index][0]
+
+
# Public symbol of this module, re-exported by the package-level lazy module.
__all__ = ["GPTSanJapaneseTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a4b3eb1be2b4e69dcad1540abdd91f412de0ca2
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
# Standard HF lazy-import pattern: expose the real symbols to static type
# checkers only, and at runtime replace this module with a _LazyModule that
# defers the heavy submodule imports until first attribute access.
if TYPE_CHECKING:
    from .configuration_graphormer import *
    from .modeling_graphormer import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..a0fafbdee53b55efb9596036817b03be0d006992
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx
@@ -0,0 +1,107 @@
+# Copyright (c) Microsoft Corporation and HuggingFace
+# Licensed under the MIT License.
+
+import cython
+
+cimport numpy
+from cython.parallel cimport parallel, prange
+
+import numpy as np
+
+
# Sentinel distance assigned to unreachable node pairs (also the saturation
# cap). Reduce this number if matrices are too big for large graphs.
UNREACHABLE_NODE_DISTANCE = 510
+
def floyd_warshall(adjacency_matrix):
    """
    Applies the Floyd-Warshall algorithm to the adjacency matrix, to compute the
    shortest paths distance between all nodes, up to UNREACHABLE_NODE_DISTANCE.

    Returns (M, path): M[i][j] is the shortest distance, and path[i][j] is the
    intermediate node recorded for reconstruction (-1 when the edge is direct,
    UNREACHABLE_NODE_DISTANCE when no path exists).
    """
    (nrows, ncols) = adjacency_matrix.shape
    assert nrows == ncols
    cdef unsigned int n = nrows

    # Work on a C-contiguous int32 copy so raw pointer arithmetic below is valid.
    adj_mat_copy = adjacency_matrix.astype(np.int32, order='C', casting='safe', copy=True)
    assert adj_mat_copy.flags['C_CONTIGUOUS']
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] M = adj_mat_copy
    cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] path = -1 * np.ones([n, n], dtype=np.int32)

    cdef unsigned int i, j, k
    cdef numpy.int32_t M_ij, M_ik, cost_ikkj
    # Flat row-major pointers avoid repeated buffer indexing in the hot loop.
    cdef numpy.int32_t* M_ptr = &M[0,0]
    cdef numpy.int32_t* M_i_ptr
    cdef numpy.int32_t* M_k_ptr

    # set unreachable nodes distance to UNREACHABLE_NODE_DISTANCE
    for i in range(n):
        for j in range(n):
            if i == j:
                M[i][j] = 0
            elif M[i][j] == 0:
                M[i][j] = UNREACHABLE_NODE_DISTANCE

    # Floyd-Warshall triple loop over intermediate node k.
    for k in range(n):
        M_k_ptr = M_ptr + n*k
        for i in range(n):
            M_i_ptr = M_ptr + n*i
            M_ik = M_i_ptr[k]
            for j in range(n):
                cost_ikkj = M_ik + M_k_ptr[j]
                M_ij = M_i_ptr[j]
                if M_ij > cost_ikkj:
                    # Relax i->j through k and remember k for reconstruction.
                    M_i_ptr[j] = cost_ikkj
                    path[i][j] = k

    # set unreachable path to UNREACHABLE_NODE_DISTANCE
    for i in range(n):
        for j in range(n):
            if M[i][j] >= UNREACHABLE_NODE_DISTANCE:
                path[i][j] = UNREACHABLE_NODE_DISTANCE
                M[i][j] = UNREACHABLE_NODE_DISTANCE

    return M, path
+
+
def get_all_edges(path, i, j):
    """
    Recursive function to compute all possible paths between two nodes from the graph adjacency matrix.

    Returns the ordered list of intermediate nodes on the shortest path from i
    to j (endpoints excluded), using the `path` matrix from floyd_warshall:
    path[i][j] == -1 means the edge is direct.
    """
    cdef int k = path[i][j]
    if k == -1:
        return []
    else:
        # Expand recursively around the recorded intermediate node k.
        return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)
+
+
def gen_edge_input(max_dist, path, edge_feat):
    """
    Generates the full edge feature and adjacency matrix.
    Shape: num_nodes * num_nodes * max_distance_between_nodes * num_edge_features
    Dim 1 is the input node, dim 2 the output node of the edge, dim 3 the depth of the edge, dim 4 the feature
    """
    (nrows, ncols) = path.shape
    assert nrows == ncols
    cdef unsigned int n = nrows
    cdef unsigned int max_dist_copy = max_dist

    # NOTE(review): `long` is the Py2-era name; Cython accepts it but confirm
    # it maps to the intended 64-bit integer width on all platforms (Windows
    # C long is 32-bit) — consider `np.int64` explicitly.
    path_copy = path.astype(long, order='C', casting='safe', copy=True)
    edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True)
    assert path_copy.flags['C_CONTIGUOUS']
    assert edge_feat_copy.flags['C_CONTIGUOUS']

    # -1 marks "no edge at this depth"; filled per (i, j) below.
    cdef numpy.ndarray[numpy.int32_t, ndim=4, mode='c'] edge_fea_all = -1 * np.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=np.int32)
    cdef unsigned int i, j, k, num_path, cur

    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            if path_copy[i][j] == UNREACHABLE_NODE_DISTANCE:
                continue
            # NOTE(review): `path` (the ndarray argument) is rebound here to a
            # Python list of nodes on the i->j path — intentional but shadows
            # the parameter; `path_copy` is used for lookups from here on.
            path = [i] + get_all_edges(path_copy, i, j) + [j]
            num_path = len(path) - 1
            # Record the features of each consecutive edge along the path.
            for k in range(num_path):
                edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :]

    return edge_fea_all
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2342913d63ffa120118574be4b1bd30af09157
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py
@@ -0,0 +1,134 @@
+# Copyright (c) Microsoft Corporation and HuggingFace
+# Licensed under the MIT License.
+
+from typing import Any, Dict, List, Mapping
+
+import numpy as np
+import torch
+
+from ....utils import is_cython_available, requires_backends
+
+
+if is_cython_available():
+ import pyximport
+
+ pyximport.install(setup_args={"include_dirs": np.get_include()})
+ from . import algos_graphormer # noqa E402
+
+
+def convert_to_single_emb(x, offset: int = 512):
+ feature_num = x.shape[1] if len(x.shape) > 1 else 1
+ feature_offset = 1 + np.arange(0, feature_num * offset, offset, dtype=np.int64)
+ x = x + feature_offset
+ return x
+
+
+def preprocess_item(item, keep_features=True):
+    """
+    Turn one raw graph sample (keys: `edge_index`, `num_nodes`, optionally `edge_attr`,
+    `node_feat`, `y`) into the dense arrays Graphormer consumes: node input ids, attention bias,
+    dense edge types, shortest-path distances, degrees and multi-hop edge inputs.
+
+    All discrete outputs are shifted by +1 so index 0 can serve as padding downstream.
+    Mutates and returns `item`.
+    """
+    # floyd_warshall / gen_edge_input are compiled Cython helpers.
+    requires_backends(preprocess_item, ["cython"])
+
+    if keep_features and "edge_attr" in item.keys():  # edge_attr
+        edge_attr = np.asarray(item["edge_attr"], dtype=np.int64)
+    else:
+        edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64)  # same embedding for all
+
+    if keep_features and "node_feat" in item.keys():  # input_nodes
+        node_feature = np.asarray(item["node_feat"], dtype=np.int64)
+    else:
+        node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64)  # same embedding for all
+
+    edge_index = np.asarray(item["edge_index"], dtype=np.int64)
+
+    input_nodes = convert_to_single_emb(node_feature) + 1
+    num_nodes = item["num_nodes"]
+
+    if len(edge_attr.shape) == 1:
+        edge_attr = edge_attr[:, None]
+    # Dense [num_nodes, num_nodes, num_edge_features] edge-type tensor; 0 marks "no edge".
+    attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64)
+    attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_emb(edge_attr) + 1
+
+    # node adj matrix [num_nodes, num_nodes] bool
+    adj = np.zeros([num_nodes, num_nodes], dtype=bool)
+    adj[edge_index[0], edge_index[1]] = True
+
+    # All-pairs shortest paths and the intermediate-node matrix used to rebuild each path.
+    shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
+    max_dist = np.amax(shortest_path_result)
+
+    input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)
+    attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single)  # with graph token
+
+    # combine
+    item["input_nodes"] = input_nodes + 1  # we shift all indices by one for padding
+    item["attn_bias"] = attn_bias
+    item["attn_edge_type"] = attn_edge_type
+    item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1  # we shift all indices by one for padding
+    item["in_degree"] = np.sum(adj, axis=1).reshape(-1) + 1  # we shift all indices by one for padding
+    item["out_degree"] = item["in_degree"]  # for undirected graph
+    item["input_edges"] = input_edges + 1  # we shift all indices by one for padding
+    if "labels" not in item:
+        item["labels"] = item["y"]
+
+    return item
+
+
+class GraphormerDataCollator:
+    """
+    Collate preprocessed graph samples into padded batch tensors for Graphormer.
+
+    Every per-graph tensor is padded with zeros up to the largest node count in the batch; index 0
+    is reserved for padding by the +1 shifts applied in `preprocess_item`.
+    """
+
+    def __init__(self, spatial_pos_max=20, on_the_fly_processing=False):
+        # spatial_pos_max: shortest-path distance at or beyond which attention is masked (-inf bias).
+        # on_the_fly_processing: run `preprocess_item` on each raw sample inside __call__.
+        if not is_cython_available():
+            raise ImportError("Graphormer preprocessing needs Cython (pyximport)")
+
+        self.spatial_pos_max = spatial_pos_max
+        self.on_the_fly_processing = on_the_fly_processing
+
+    def __call__(self, features: List[dict]) -> Dict[str, Any]:
+        """Collate `features` (list of preprocessed items) into a dict of batched tensors."""
+        if self.on_the_fly_processing:
+            features = [preprocess_item(i) for i in features]
+
+        # Allow plain objects with attributes instead of mappings.
+        if not isinstance(features[0], Mapping):
+            features = [vars(f) for f in features]
+        batch = {}
+
+        # Batch-wide sizes: node count from the largest graph, feature dims from the first sample.
+        max_node_num = max(len(i["input_nodes"]) for i in features)
+        node_feat_size = len(features[0]["input_nodes"][0])
+        edge_feat_size = len(features[0]["attn_edge_type"][0][0])
+        max_dist = max(len(i["input_edges"][0][0]) for i in features)
+        edge_input_size = len(features[0]["input_edges"][0][0][0])
+        batch_size = len(features)
+
+        # attn_bias/spatial matrices carry an extra row/column for the graph token.
+        batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float)
+        batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long)
+        batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long)
+        batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long)
+        batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long)
+        batch["input_edges"] = torch.zeros(
+            batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long
+        )
+
+        for ix, f in enumerate(features):
+            for k in ["attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges"]:
+                f[k] = torch.tensor(f[k])
+
+            # Mask attention (additive -inf) between node pairs farther apart than spatial_pos_max.
+            if len(f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max]) > 0:
+                f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max] = float("-inf")
+
+            # Copy each sample's tensors into the top-left corner of its padded batch slot.
+            batch["attn_bias"][ix, : f["attn_bias"].shape[0], : f["attn_bias"].shape[1]] = f["attn_bias"]
+            batch["attn_edge_type"][ix, : f["attn_edge_type"].shape[0], : f["attn_edge_type"].shape[1], :] = f[
+                "attn_edge_type"
+            ]
+            batch["spatial_pos"][ix, : f["spatial_pos"].shape[0], : f["spatial_pos"].shape[1]] = f["spatial_pos"]
+            batch["in_degree"][ix, : f["in_degree"].shape[0]] = f["in_degree"]
+            batch["input_nodes"][ix, : f["input_nodes"].shape[0], :] = f["input_nodes"]
+            batch["input_edges"][
+                ix, : f["input_edges"].shape[0], : f["input_edges"].shape[1], : f["input_edges"].shape[2], :
+            ] = f["input_edges"]
+
+        # NOTE(review): out_degree aliases the in_degree tensor (same object, not a copy) — graphs
+        # are treated as undirected here.
+        batch["out_degree"] = batch["in_degree"]
+
+        sample = features[0]["labels"]
+        # NOTE(review): the regression and binary-classification branches below are identical; the
+        # split is kept for symmetry with the multi-task branch.
+        if len(sample) == 1:  # one task
+            if isinstance(sample[0], float):  # regression
+                batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
+            else:  # binary classification
+                batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features]))
+        else:  # multi task classification, left to float to keep the NaNs
+            batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0))
+
+        return batch
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ecde152e4e522ab7ca5c7d1c5d2bc5b4fdda029
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py
@@ -0,0 +1,220 @@
+# coding=utf-8
+# Copyright 2022 Microsoft, clefourrier and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Graphormer model configuration"""
+
+from typing import Optional
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GraphormerConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`~GraphormerModel`]. It is used to instantiate an
+    Graphormer model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the Graphormer
+    [graphormer-base-pcqm4mv1](https://huggingface.co/graphormer-base-pcqm4mv1) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        num_classes (`int`, *optional*, defaults to 1):
+            Number of target classes or labels, set to n for binary classification of n tasks.
+        num_atoms (`int`, *optional*, defaults to 512*9):
+            Number of node types in the graphs.
+        num_edges (`int`, *optional*, defaults to 512*3):
+            Number of edges types in the graph.
+        num_in_degree (`int`, *optional*, defaults to 512):
+            Number of in degrees types in the input graphs.
+        num_out_degree (`int`, *optional*, defaults to 512):
+            Number of out degrees types in the input graphs.
+        num_spatial (`int`, *optional*, defaults to 512):
+            Number of spatial (shortest-path distance) types in the input graphs.
+        num_edge_dis (`int`, *optional*, defaults to 128):
+            Number of edge dis in the input graphs.
+        multi_hop_max_dist (`int`, *optional*, defaults to 5):
+            Maximum distance of multi hop edges between two nodes.
+        spatial_pos_max (`int`, *optional*, defaults to 1024):
+            Maximum distance between nodes in the graph attention bias matrices, used during preprocessing and
+            collation.
+        edge_type (`str`, *optional*, defaults to `"multi_hop"`):
+            Type of edge relation chosen.
+        max_nodes (`int`, *optional*, defaults to 512):
+            Maximum number of nodes which can be parsed for the input graphs.
+        share_input_output_embed (`bool`, *optional*, defaults to `False`):
+            Shares the embedding layer between encoder and decoder - careful, True is not implemented.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of layers.
+        embedding_dim (`int`, *optional*, defaults to 768):
+            Dimension of the embedding layer in encoder.
+        ffn_embedding_dim (`int`, *optional*, defaults to 768):
+            Dimension of the "intermediate" (often named feed-forward) layer in encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads in the encoder.
+        self_attention (`bool`, *optional*, defaults to `True`):
+            Model is self attentive (False not implemented).
+        activation_fn (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the attention weights.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the activation of the linear transformer layer.
+        layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        bias (`bool`, *optional*, defaults to `True`):
+            Uses bias in the attention module - unsupported at the moment.
+        embed_scale(`float`, *optional*, defaults to None):
+            Scaling factor for the node embeddings.
+        num_trans_layers_to_freeze (`int`, *optional*, defaults to 0):
+            Number of transformer layers to freeze.
+        encoder_normalize_before (`bool`, *optional*, defaults to `False`):
+            Normalize features before encoding the graph.
+        pre_layernorm (`bool`, *optional*, defaults to `False`):
+            Apply layernorm before self attention and the feed forward network. Without this, post layernorm will be
+            used.
+        apply_graphormer_init (`bool`, *optional*, defaults to `False`):
+            Apply a custom graphormer initialisation to the model before training.
+        freeze_embeddings (`bool`, *optional*, defaults to `False`):
+            Freeze the embedding layer, or train it along the model.
+        q_noise (`float`, *optional*, defaults to 0.0):
+            Amount of quantization noise (see "Training with Quantization Noise for Extreme Model Compression"). (For
+            more detail, see fairseq's documentation on quant_noise).
+        qn_block_size (`int`, *optional*, defaults to 8):
+            Size of the blocks for subsequent quantization with iPQ (see q_noise).
+        kdim (`int`, *optional*, defaults to None):
+            Dimension of the key in the attention, if different from the other values.
+        vdim (`int`, *optional*, defaults to None):
+            Dimension of the value in the attention, if different from the other values.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        traceable (`bool`, *optional*, defaults to `False`):
+            Changes return value of the encoder's inner_state to stacked tensors.
+
+    Example:
+    ```python
+    >>> from transformers import GraphormerForGraphClassification, GraphormerConfig
+
+    >>> # Initializing a Graphormer graphormer-base-pcqm4mv1 style configuration
+    >>> configuration = GraphormerConfig()
+
+    >>> # Initializing a model from the graphormer-base-pcqm4mv1 style configuration
+    >>> model = GraphormerForGraphClassification(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "graphormer"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        num_classes: int = 1,
+        num_atoms: int = 512 * 9,
+        num_edges: int = 512 * 3,
+        num_in_degree: int = 512,
+        num_out_degree: int = 512,
+        num_spatial: int = 512,
+        num_edge_dis: int = 128,
+        multi_hop_max_dist: int = 5,  # sometimes is 20
+        spatial_pos_max: int = 1024,
+        edge_type: str = "multi_hop",
+        max_nodes: int = 512,
+        share_input_output_embed: bool = False,
+        num_hidden_layers: int = 12,
+        embedding_dim: int = 768,
+        ffn_embedding_dim: int = 768,
+        num_attention_heads: int = 32,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.1,
+        layerdrop: float = 0.0,
+        encoder_normalize_before: bool = False,
+        pre_layernorm: bool = False,
+        apply_graphormer_init: bool = False,
+        activation_fn: str = "gelu",
+        embed_scale: Optional[float] = None,
+        freeze_embeddings: bool = False,
+        num_trans_layers_to_freeze: int = 0,
+        traceable: bool = False,
+        q_noise: float = 0.0,
+        qn_block_size: int = 8,
+        kdim: Optional[int] = None,
+        vdim: Optional[int] = None,
+        bias: bool = True,
+        self_attention: bool = True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        self.num_classes = num_classes
+        self.num_atoms = num_atoms
+        self.num_in_degree = num_in_degree
+        self.num_out_degree = num_out_degree
+        self.num_edges = num_edges
+        self.num_spatial = num_spatial
+        self.num_edge_dis = num_edge_dis
+        self.edge_type = edge_type
+        self.multi_hop_max_dist = multi_hop_max_dist
+        self.spatial_pos_max = spatial_pos_max
+        self.max_nodes = max_nodes
+        self.num_hidden_layers = num_hidden_layers
+        self.embedding_dim = embedding_dim
+        # hidden_size mirrors embedding_dim for compatibility with generic Transformers utilities.
+        self.hidden_size = embedding_dim
+        self.ffn_embedding_dim = ffn_embedding_dim
+        self.num_attention_heads = num_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.layerdrop = layerdrop
+        self.encoder_normalize_before = encoder_normalize_before
+        self.pre_layernorm = pre_layernorm
+        self.apply_graphormer_init = apply_graphormer_init
+        self.activation_fn = activation_fn
+        self.embed_scale = embed_scale
+        self.freeze_embeddings = freeze_embeddings
+        self.num_trans_layers_to_freeze = num_trans_layers_to_freeze
+        self.share_input_output_embed = share_input_output_embed
+        self.traceable = traceable
+        self.q_noise = q_noise
+        self.qn_block_size = qn_block_size
+
+        # These parameters are here for future extensions
+        # atm, the model only supports self attention
+        self.kdim = kdim
+        self.vdim = vdim
+        self.self_attention = self_attention
+        self.bias = bias
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            **kwargs,
+        )
+
+
+__all__ = ["GraphormerConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1253d1365e8528da0ec4543d39d9a8f66f2255f7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py
@@ -0,0 +1,911 @@
+# coding=utf-8
+# Copyright 2022 Microsoft, clefourrier The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Graphormer model."""
+
+import math
+from typing import Iterable, Iterator, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ SequenceClassifierOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....utils import logging
+from .configuration_graphormer import GraphormerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "graphormer-base-pcqm4mv1"
+_CONFIG_FOR_DOC = "GraphormerConfig"
+
+
+def quant_noise(module: nn.Module, p: float, block_size: int):
+    """
+    From:
+    https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py
+
+    Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product
+    Quantization as described in "Training with Quantization Noise for Extreme Model Compression"
+
+    Args:
+        - module: nn.Module
+        - p: amount of Quantization Noise
+        - block_size: size of the blocks for subsequent quantization with iPQ
+
+    Remarks:
+        - Module weights must have the right sizes wrt the block size
+        - Only Linear, Embedding and Conv2d modules are supported for the moment
+        - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down:
+          Revisiting the Quantization of Neural Networks"
+        - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping
+          blocks
+    """
+
+    # if no quantization noise, don't register hook
+    if p <= 0:
+        return module
+
+    # supported modules
+    if not isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)):
+        raise NotImplementedError("Module unsupported for quant_noise.")
+
+    # test whether module.weight has the right sizes wrt block_size
+    is_conv = module.weight.ndim == 4
+
+    # 2D matrix
+    if not is_conv:
+        if module.weight.size(1) % block_size != 0:
+            raise AssertionError("Input features must be a multiple of block sizes")
+
+    # 4D matrix
+    else:
+        # 1x1 convolutions
+        if module.kernel_size == (1, 1):
+            if module.in_channels % block_size != 0:
+                raise AssertionError("Input channels must be a multiple of block sizes")
+        # regular convolutions
+        else:
+            k = module.kernel_size[0] * module.kernel_size[1]
+            if k % block_size != 0:
+                raise AssertionError("Kernel size must be a multiple of block size")
+
+    def _forward_pre_hook(mod, input):
+        # no noise for evaluation
+        if mod.training:
+            if not is_conv:
+                # gather weight and sizes
+                weight = mod.weight
+                in_features = weight.size(1)
+                out_features = weight.size(0)
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                # bernoulli_(p) fills with 1 with probability p, so mask==1 marks a *dropped* block.
+                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+                mask.bernoulli_(p)
+                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+            else:
+                # gather weight and sizes
+                weight = mod.weight
+                in_channels = mod.in_channels
+                out_channels = mod.out_channels
+
+                # split weight matrix into blocks and randomly drop selected blocks
+                if mod.kernel_size == (1, 1):
+                    mask = torch.zeros(
+                        int(in_channels // block_size * out_channels),
+                        device=weight.device,
+                    )
+                    mask.bernoulli_(p)
+                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+                else:
+                    # For regular convolutions, blocks span the kernel spatial dims per (out, in) pair.
+                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+                    mask.bernoulli_(p)
+                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+
+            # scale weights and apply mask
+            # Kept weights are rescaled by 1/(1-p) so the expected magnitude is unchanged.
+            # NOTE(review): this overwrites mod.weight.data in place on every training forward,
+            # as in the fairseq original.
+            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
+            s = 1 / (1 - p)
+            mod.weight.data = s * weight.masked_fill(mask, 0)
+
+    module.register_forward_pre_hook(_forward_pre_hook)
+    return module
+
+
+class LayerDropModuleList(nn.ModuleList):
+ """
+ From:
+ https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/layer_drop.py
+ A LayerDrop implementation based on [`torch.nn.ModuleList`]. LayerDrop as described in
+ https://arxiv.org/abs/1909.11556.
+
+ We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During
+ evaluation we always iterate over all layers.
+
+ Usage:
+
+ ```python
+ layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3])
+ for layer in layers: # this might iterate over layers 1 and 3
+ x = layer(x)
+ for layer in layers: # this might iterate over all layers
+ x = layer(x)
+ for layer in layers: # this might not iterate over any layers
+ x = layer(x)
+ ```
+
+ Args:
+ p (float): probability of dropping out each layer
+ modules (iterable, optional): an iterable of modules to add
+ """
+
+ def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None):
+ super().__init__(modules)
+ self.p = p
+
+ def __iter__(self) -> Iterator[nn.Module]:
+ dropout_probs = torch.empty(len(self)).uniform_()
+ for i, m in enumerate(super().__iter__()):
+ if not self.training or (dropout_probs[i] > self.p):
+ yield m
+
+
+class GraphormerGraphNodeFeature(nn.Module):
+ """
+ Compute node features for each node in the graph.
+ """
+
+ def __init__(self, config: GraphormerConfig):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.num_atoms = config.num_atoms
+
+ self.atom_encoder = nn.Embedding(config.num_atoms + 1, config.hidden_size, padding_idx=config.pad_token_id)
+ self.in_degree_encoder = nn.Embedding(
+ config.num_in_degree, config.hidden_size, padding_idx=config.pad_token_id
+ )
+ self.out_degree_encoder = nn.Embedding(
+ config.num_out_degree, config.hidden_size, padding_idx=config.pad_token_id
+ )
+
+ self.graph_token = nn.Embedding(1, config.hidden_size)
+
+ def forward(
+ self,
+ input_nodes: torch.LongTensor,
+ in_degree: torch.LongTensor,
+ out_degree: torch.LongTensor,
+ ) -> torch.Tensor:
+ n_graph, n_node = input_nodes.size()[:2]
+
+ node_feature = ( # node feature + graph token
+ self.atom_encoder(input_nodes).sum(dim=-2) # [n_graph, n_node, n_hidden]
+ + self.in_degree_encoder(in_degree)
+ + self.out_degree_encoder(out_degree)
+ )
+
+ graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
+
+ graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)
+
+ return graph_node_feature
+
+
+class GraphormerGraphAttnBias(nn.Module):
+    """
+    Compute attention bias for each head.
+
+    Combines the spatial (shortest-path) encoding, a learned graph-token virtual distance, and the
+    edge-feature encoding into an additive bias of shape [n_graph, n_head, n_node + 1, n_node + 1]
+    (the extra row/column is the graph token).
+    """
+
+    def __init__(self, config: GraphormerConfig):
+        super().__init__()
+        self.num_heads = config.num_attention_heads
+        self.multi_hop_max_dist = config.multi_hop_max_dist
+
+        # We do not change edge feature embedding learning, as edge embeddings are represented as a combination of the original features
+        # + shortest path
+        self.edge_encoder = nn.Embedding(config.num_edges + 1, config.num_attention_heads, padding_idx=0)
+
+        self.edge_type = config.edge_type
+        if self.edge_type == "multi_hop":
+            # Flat storage of one [n_head, n_head] mixing matrix per hop distance.
+            self.edge_dis_encoder = nn.Embedding(
+                config.num_edge_dis * config.num_attention_heads * config.num_attention_heads,
+                1,
+            )
+
+        self.spatial_pos_encoder = nn.Embedding(config.num_spatial, config.num_attention_heads, padding_idx=0)
+
+        self.graph_token_virtual_distance = nn.Embedding(1, config.num_attention_heads)
+
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        spatial_pos: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
+    ) -> torch.Tensor:
+        n_graph, n_node = input_nodes.size()[:2]
+        graph_attn_bias = attn_bias.clone()
+        graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(
+            1, self.num_heads, 1, 1
+        )  # [n_graph, n_head, n_node+1, n_node+1]
+
+        # spatial pos
+        # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
+        spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
+        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
+
+        # reset spatial pos here
+        # The graph token attends to / is attended by every node at a single learned "virtual" distance.
+        t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1)
+        graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
+        graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
+
+        # edge feature
+        if self.edge_type == "multi_hop":
+            spatial_pos_ = spatial_pos.clone()
+
+            spatial_pos_[spatial_pos_ == 0] = 1  # set pad to 1
+            # set 1 to 1, input_nodes > 1 to input_nodes - 1
+            spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
+            if self.multi_hop_max_dist > 0:
+                spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
+                input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :]
+            # [n_graph, n_node, n_node, max_dist, n_head]
+
+            # Average over the edge-feature axis, then mix per-hop via the distance encoder.
+            input_edges = self.edge_encoder(input_edges).mean(-2)
+            max_dist = input_edges.size(-2)
+            edge_input_flat = input_edges.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads)
+            edge_input_flat = torch.bmm(
+                edge_input_flat,
+                self.edge_dis_encoder.weight.reshape(-1, self.num_heads, self.num_heads)[:max_dist, :, :],
+            )
+            input_edges = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.num_heads).permute(
+                1, 2, 3, 0, 4
+            )
+            # Normalize the per-hop sum by the (clamped) path length.
+            input_edges = (input_edges.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
+        else:
+            # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
+            input_edges = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
+
+        graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + input_edges
+        graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1)  # reset
+
+        return graph_attn_bias
+
+
+class GraphormerMultiheadAttention(nn.Module):
+ """Multi-headed attention.
+
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(self, config: GraphormerConfig):
+ super().__init__()
+ self.embedding_dim = config.embedding_dim
+ self.kdim = config.kdim if config.kdim is not None else config.embedding_dim
+ self.vdim = config.vdim if config.vdim is not None else config.embedding_dim
+ self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim
+
+ self.num_heads = config.num_attention_heads
+ self.attention_dropout_module = torch.nn.Dropout(p=config.attention_dropout, inplace=False)
+
+ self.head_dim = config.embedding_dim // config.num_attention_heads
+ if not (self.head_dim * config.num_attention_heads == self.embedding_dim):
+ raise AssertionError("The embedding_dim must be divisible by num_heads.")
+ self.scaling = self.head_dim**-0.5
+
+ self.self_attention = True # config.self_attention
+ if not (self.self_attention):
+ raise NotImplementedError("The Graphormer model only supports self attention for now.")
+ if self.self_attention and not self.qkv_same_dim:
+ raise AssertionError("Self-attention requires query, key and value to be of the same size.")
+
+ self.k_proj = quant_noise(
+ nn.Linear(self.kdim, config.embedding_dim, bias=config.bias),
+ config.q_noise,
+ config.qn_block_size,
+ )
+ self.v_proj = quant_noise(
+ nn.Linear(self.vdim, config.embedding_dim, bias=config.bias),
+ config.q_noise,
+ config.qn_block_size,
+ )
+ self.q_proj = quant_noise(
+ nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
+ config.q_noise,
+ config.qn_block_size,
+ )
+
+ self.out_proj = quant_noise(
+ nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias),
+ config.q_noise,
+ config.qn_block_size,
+ )
+
+ self.onnx_trace = False
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ # Empirically observed the convergence to be much better with
+ # the scaled initialization
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ else:
+ nn.init.xavier_uniform_(self.k_proj.weight)
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ nn.init.xavier_uniform_(self.q_proj.weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.out_proj.bias is not None:
+ nn.init.constant_(self.out_proj.bias, 0.0)
+
+ def forward(
+ self,
+ query: torch.LongTensor,
+ key: Optional[torch.Tensor],
+ value: Optional[torch.Tensor],
+ attn_bias: Optional[torch.Tensor],
+ key_padding_mask: Optional[torch.Tensor] = None,
+ need_weights: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Args:
+ key_padding_mask (Bytetorch.Tensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (Bytetorch.Tensor, optional): typically used to
+ implement causal attention, where the mask prevents the attention from looking forward in time
+ (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default: return the average attention weights over all
+ heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embedding_dim = query.size()
+ src_len = tgt_len
+ if not (embedding_dim == self.embedding_dim):
+ raise AssertionError(
+ f"The query embedding dimension {embedding_dim} is not equal to the expected embedding_dim"
+ f" {self.embedding_dim}."
+ )
+ if not (list(query.size()) == [tgt_len, bsz, embedding_dim]):
+ raise AssertionError("Query size incorrect in Graphormer, compared to model dimensions.")
+
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ if (key_bsz != bsz) or (value is None) or not (src_len, bsz == value.shape[:2]):
+ raise AssertionError(
+ "The batch shape does not match the key or value shapes provided to the attention."
+ )
+
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+
+ q *= self.scaling
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if (k is None) or not (k.size(1) == src_len):
+ raise AssertionError("The shape of the key generated in the attention is incorrect")
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ if key_padding_mask.size(0) != bsz or key_padding_mask.size(1) != src_len:
+ raise AssertionError(
+ "The shape of the generated padding mask for the key does not match expected dimensions."
+ )
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ if list(attn_weights.size()) != [bsz * self.num_heads, tgt_len, src_len]:
+ raise AssertionError("The attention weights generated do not match the expected dimensions.")
+
+ if attn_bias is not None:
+ attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.attention_dropout_module(attn_weights)
+
+ if v is None:
+ raise AssertionError("No value generated")
+ attn = torch.bmm(attn_probs, v)
+ if list(attn.size()) != [bsz * self.num_heads, tgt_len, self.head_dim]:
+ raise AssertionError("The attention generated do not match the expected dimensions.")
+
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim)
+ attn: torch.Tensor = self.out_proj(attn)
+
+ attn_weights = None
+ if need_weights:
+ attn_weights = attn_weights_float.contiguous().view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights
+
+    def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor:
+        """Hook for subclasses to sparsify attention weights; the base implementation returns them unchanged."""
+        return attn_weights
+
+
+class GraphormerGraphEncoderLayer(nn.Module):
+    """One Graphormer transformer layer: multi-head self-attention followed by a two-layer
+    feed-forward block, each with a residual connection and LayerNorm applied either before
+    (pre-norm) or after (post-norm) the sublayer, depending on ``config.pre_layernorm``.
+    """
+
+    def __init__(self, config: GraphormerConfig) -> None:
+        super().__init__()
+
+        # Initialize parameters
+        self.embedding_dim = config.embedding_dim
+        self.num_attention_heads = config.num_attention_heads
+        self.q_noise = config.q_noise
+        self.qn_block_size = config.qn_block_size
+        self.pre_layernorm = config.pre_layernorm
+
+        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
+
+        self.activation_dropout_module = torch.nn.Dropout(p=config.activation_dropout, inplace=False)
+
+        # Initialize blocks
+        self.activation_fn = ACT2FN[config.activation_fn]
+        self.self_attn = GraphormerMultiheadAttention(config)
+
+        # layer norm associated with the self attention layer
+        self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
+
+        # Feed-forward projections, wrapped in quant_noise for optional quantization-aware noise.
+        self.fc1 = self.build_fc(
+            self.embedding_dim,
+            config.ffn_embedding_dim,
+            q_noise=config.q_noise,
+            qn_block_size=config.qn_block_size,
+        )
+        self.fc2 = self.build_fc(
+            config.ffn_embedding_dim,
+            self.embedding_dim,
+            q_noise=config.q_noise,
+            qn_block_size=config.qn_block_size,
+        )
+
+        # layer norm associated with the position wise feed-forward NN
+        self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
+
+    def build_fc(
+        self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int
+    ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]:
+        """Build a linear projection wrapped with quantization noise (no-op when q_noise == 0)."""
+        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)
+
+    def forward(
+        self,
+        input_nodes: torch.Tensor,
+        self_attn_bias: Optional[torch.Tensor] = None,
+        self_attn_mask: Optional[torch.Tensor] = None,
+        self_attn_padding_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        nn.LayerNorm is applied either before or after the self-attention/ffn modules similar to the original
+        Transformer implementation.
+
+        Returns the transformed node embeddings and the attention weights; the latter is None
+        here because ``need_weights=False`` is passed to the attention module.
+        """
+        # --- self-attention sublayer (residual + norm) ---
+        residual = input_nodes
+        if self.pre_layernorm:
+            input_nodes = self.self_attn_layer_norm(input_nodes)
+
+        input_nodes, attn = self.self_attn(
+            query=input_nodes,
+            key=input_nodes,
+            value=input_nodes,
+            attn_bias=self_attn_bias,
+            key_padding_mask=self_attn_padding_mask,
+            need_weights=False,
+            attn_mask=self_attn_mask,
+        )
+        input_nodes = self.dropout_module(input_nodes)
+        input_nodes = residual + input_nodes
+        if not self.pre_layernorm:
+            input_nodes = self.self_attn_layer_norm(input_nodes)
+
+        # --- feed-forward sublayer (residual + norm) ---
+        residual = input_nodes
+        if self.pre_layernorm:
+            input_nodes = self.final_layer_norm(input_nodes)
+        input_nodes = self.activation_fn(self.fc1(input_nodes))
+        input_nodes = self.activation_dropout_module(input_nodes)
+        input_nodes = self.fc2(input_nodes)
+        input_nodes = self.dropout_module(input_nodes)
+        input_nodes = residual + input_nodes
+        if not self.pre_layernorm:
+            input_nodes = self.final_layer_norm(input_nodes)
+
+        return input_nodes, attn
+
+
+class GraphormerGraphEncoder(nn.Module):
+    """Stack of Graphormer layers. Embeds node features (with degree information), computes the
+    structural attention bias, then runs the layer stack, optionally with LayerDrop and
+    quantization noise.
+    """
+
+    def __init__(self, config: GraphormerConfig):
+        super().__init__()
+
+        self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False)
+        self.layerdrop = config.layerdrop
+        self.embedding_dim = config.embedding_dim
+        self.apply_graphormer_init = config.apply_graphormer_init
+        # When traceable, forward returns stacked inner states (a single tensor) for torch.jit.
+        self.traceable = config.traceable
+
+        self.graph_node_feature = GraphormerGraphNodeFeature(config)
+        self.graph_attn_bias = GraphormerGraphAttnBias(config)
+
+        self.embed_scale = config.embed_scale
+
+        # Optional quantization-noise projection applied to the embeddings.
+        if config.q_noise > 0:
+            self.quant_noise = quant_noise(
+                nn.Linear(self.embedding_dim, self.embedding_dim, bias=False),
+                config.q_noise,
+                config.qn_block_size,
+            )
+        else:
+            self.quant_noise = None
+
+        if config.encoder_normalize_before:
+            self.emb_layer_norm = nn.LayerNorm(self.embedding_dim)
+        else:
+            self.emb_layer_norm = None
+
+        if config.pre_layernorm:
+            self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
+
+        # LayerDropModuleList randomly skips whole layers at train time when layerdrop > 0.
+        if self.layerdrop > 0.0:
+            self.layers = LayerDropModuleList(p=self.layerdrop)
+        else:
+            self.layers = nn.ModuleList([])
+        self.layers.extend([GraphormerGraphEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+
+        # Apply initialization of model params after building the model
+        if config.freeze_embeddings:
+            raise NotImplementedError("Freezing embeddings is not implemented yet.")
+
+        # Freeze the first `num_trans_layers_to_freeze` transformer layers (e.g. for fine-tuning).
+        for layer in range(config.num_trans_layers_to_freeze):
+            m = self.layers[layer]
+            if m is not None:
+                for p in m.parameters():
+                    p.requires_grad = False
+
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
+        perturb=None,
+        last_state_only: bool = False,
+        token_embeddings: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]:
+        # compute padding mask. This is needed for multi-head attention
+        data_x = input_nodes
+        n_graph, n_node = data_x.size()[:2]
+        # A node is padding when its first feature is 0 (presumably the pad token id — TODO confirm).
+        padding_mask = (data_x[:, :, 0]).eq(0)
+        # Prepend a never-masked slot for the graph-level token added by graph_node_feature.
+        padding_mask_cls = torch.zeros(n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype)
+        padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1)
+
+        attn_bias = self.graph_attn_bias(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type)
+
+        if token_embeddings is not None:
+            input_nodes = token_embeddings
+        else:
+            input_nodes = self.graph_node_feature(input_nodes, in_degree, out_degree)
+
+        # Optional adversarial perturbation added to all node embeddings except position 0.
+        if perturb is not None:
+            input_nodes[:, 1:, :] += perturb
+
+        if self.embed_scale is not None:
+            input_nodes = input_nodes * self.embed_scale
+
+        if self.quant_noise is not None:
+            input_nodes = self.quant_noise(input_nodes)
+
+        if self.emb_layer_norm is not None:
+            input_nodes = self.emb_layer_norm(input_nodes)
+
+        input_nodes = self.dropout_module(input_nodes)
+
+        # Layers expect sequence-first layout: (n_node + 1, n_graph, embedding_dim).
+        input_nodes = input_nodes.transpose(0, 1)
+
+        inner_states = []
+        if not last_state_only:
+            inner_states.append(input_nodes)
+
+        for layer in self.layers:
+            input_nodes, _ = layer(
+                input_nodes,
+                self_attn_padding_mask=padding_mask,
+                self_attn_mask=attn_mask,
+                self_attn_bias=attn_bias,
+            )
+            if not last_state_only:
+                inner_states.append(input_nodes)
+
+        # Graph-level representation: the token at position 0 (the prepended graph token).
+        graph_rep = input_nodes[0, :, :]
+
+        if last_state_only:
+            inner_states = [input_nodes]
+
+        if self.traceable:
+            return torch.stack(inner_states), graph_rep
+        else:
+            return inner_states, graph_rep
+
+
+class GraphormerDecoderHead(nn.Module):
+ def __init__(self, embedding_dim: int, num_classes: int):
+ super().__init__()
+ """num_classes should be 1 for regression, or the number of classes for classification"""
+ self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
+ self.classifier = nn.Linear(embedding_dim, num_classes, bias=False)
+ self.num_classes = num_classes
+
+ def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor:
+ input_nodes = self.classifier(input_nodes)
+ input_nodes = input_nodes + self.lm_output_learned_bias
+ return input_nodes
+
+
+class GraphormerPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = GraphormerConfig
+ base_model_prefix = "graphormer"
+ main_input_name_nodes = "input_nodes"
+ main_input_name_edges = "input_edges"
+
+ def normal_(self, data: torch.Tensor):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+ def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]):
+ """
+ Initialize the weights specific to the Graphormer Model.
+ """
+ if isinstance(module, nn.Linear):
+ self.normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ self.normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, GraphormerMultiheadAttention):
+ self.normal_(module.q_proj.weight.data)
+ self.normal_(module.k_proj.weight.data)
+ self.normal_(module.v_proj.weight.data)
+
+ def _init_weights(
+ self,
+ module: Union[
+ nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder
+ ],
+ ):
+ """
+ Initialize the weights
+ """
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # We might be missing part of the Linear init, dependant on the layer num
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, GraphormerMultiheadAttention):
+ module.q_proj.weight.data.normal_(mean=0.0, std=0.02)
+ module.k_proj.weight.data.normal_(mean=0.0, std=0.02)
+ module.v_proj.weight.data.normal_(mean=0.0, std=0.02)
+ module.reset_parameters()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, GraphormerGraphEncoder):
+ if module.apply_graphormer_init:
+ module.apply(self.init_graphormer_params)
+
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+
+class GraphormerModel(GraphormerPreTrainedModel):
+    """The Graphormer model is a graph-encoder model.
+
+    It goes from a graph to its representation. If you want to use the model for a downstream classification task, use
+    GraphormerForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine
+    this model with a downstream model of your choice, following the example in GraphormerForGraphClassification.
+    """
+
+    def __init__(self, config: GraphormerConfig):
+        super().__init__(config)
+        # NOTE(review): this instance attribute shadows the `max_nodes` method defined below, so
+        # `model.max_nodes` yields the int and the method is unreachable on instances — confirm intent.
+        self.max_nodes = config.max_nodes
+
+        self.graph_encoder = GraphormerGraphEncoder(config)
+
+        self.share_input_output_embed = config.share_input_output_embed
+        self.lm_output_learned_bias = None
+
+        # Remove head is set to true during fine-tuning
+        self.load_softmax = not getattr(config, "remove_head", False)
+
+        self.lm_head_transform_weight = nn.Linear(config.embedding_dim, config.embedding_dim)
+        self.activation_fn = ACT2FN[config.activation_fn]
+        self.layer_norm = nn.LayerNorm(config.embedding_dim)
+
+        self.post_init()
+
+    def reset_output_layer_parameters(self):
+        """Re-create the learned output bias (e.g. after removing/re-adding the head)."""
+        self.lm_output_learned_bias = nn.Parameter(torch.zeros(1))
+
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
+        perturb: Optional[torch.FloatTensor] = None,
+        masked_tokens: None = None,
+        return_dict: Optional[bool] = None,
+        **unused,
+    ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        inner_states, graph_rep = self.graph_encoder(
+            input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb
+        )
+
+        # last inner state, then revert Batch and Graph len
+        input_nodes = inner_states[-1].transpose(0, 1)
+
+        # project masked tokens only
+        if masked_tokens is not None:
+            raise NotImplementedError
+
+        input_nodes = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(input_nodes)))
+
+        # project back to size of vocabulary
+        # NOTE(review): GraphormerGraphEncoder defines no `embed_tokens` attribute, so this line
+        # would raise AttributeError whenever share_input_output_embed is True — confirm intended.
+        if self.share_input_output_embed and hasattr(self.graph_encoder.embed_tokens, "weight"):
+            input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight)
+
+        if not return_dict:
+            return tuple(x for x in [input_nodes, inner_states] if x is not None)
+        return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states)
+
+    def max_nodes(self):
+        """Maximum output length supported by the encoder."""
+        # NOTE(review): shadowed by the `self.max_nodes` int assigned in __init__ (see above).
+        return self.max_nodes
+
+
+class GraphormerForGraphClassification(GraphormerPreTrainedModel):
+    """
+    This model can be used for graph-level classification or regression tasks.
+
+    It can be trained on
+    - regression (by setting config.num_classes to 1); there should be one float-type label per graph
+    - one task classification (by setting config.num_classes to the number of classes); there should be one integer
+      label per graph
+    - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list
+      of integer labels for each graph.
+    """
+
+    def __init__(self, config: GraphormerConfig):
+        super().__init__(config)
+        self.encoder = GraphormerModel(config)
+        self.embedding_dim = config.embedding_dim
+        self.num_classes = config.num_classes
+        self.classifier = GraphormerDecoderHead(self.embedding_dim, self.num_classes)
+        self.is_encoder_decoder = True
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_nodes: torch.LongTensor,
+        input_edges: torch.LongTensor,
+        attn_bias: torch.Tensor,
+        in_degree: torch.LongTensor,
+        out_degree: torch.LongTensor,
+        spatial_pos: torch.LongTensor,
+        attn_edge_type: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
+        **unused,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_outputs = self.encoder(
+            input_nodes,
+            input_edges,
+            attn_bias,
+            in_degree,
+            out_degree,
+            spatial_pos,
+            attn_edge_type,
+            return_dict=True,
+        )
+        outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"]
+
+        # Score every node, then keep position 0 (the graph-level token) as the graph logits.
+        head_outputs = self.classifier(outputs)
+        logits = head_outputs[:, 0, :].contiguous()
+
+        loss = None
+        if labels is not None:
+            # Mask out NaN labels — presumably to support missing labels in multi-task setups;
+            # note torch.isnan requires floating-point labels here — TODO confirm.
+            mask = ~torch.isnan(labels)
+
+            if self.num_classes == 1:  # regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float())
+            elif self.num_classes > 1 and len(labels.shape) == 1:  # One task classification
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1))
+            else:  # Binary multi-task classification
+                loss_fct = BCEWithLogitsLoss(reduction="sum")
+                loss = loss_fct(logits[mask], labels[mask])
+
+        if not return_dict:
+            return tuple(x for x in [loss, logits, hidden_states] if x is not None)
+        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..826bdbddc1f182de43a795ab0d78ad4009507c14
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_jukebox import *
+    from .modeling_jukebox import *
+    from .tokenization_jukebox import *
+else:
+    import sys
+
+    # Replace this module with a lazy proxy so submodules are only imported on first
+    # attribute access (keeps `import transformers` cheap); type checkers see the real
+    # imports via the TYPE_CHECKING branch above.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10cbc2d82cfbd310cff2b84b54de4d693cea1c7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py
@@ -0,0 +1,613 @@
+# coding=utf-8
+# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Jukebox configuration"""
+
+import os
+from typing import List, Union
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+# 79-entry layer -> attention-type schedule used by `large_separated_enc_dec_w_lyrics`
+# (indexed with `layer % 79`): six (block, transpose_block, prev_block) triplets, one
+# "cross_attention", then six groups of three triplets each followed by "cross_attention"
+# (18 + 1 + 6 * 10 = 79 entries).
+_LARGE_ATTENTION = [
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "block_attn",
+    "transpose_block_attn",
+    "prev_block_attn",
+    "cross_attention",
+]
+# Row / column / previous-row triplet, cycled with period 3.
+_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
+# Dense attention at every layer.
+_FullDenseAttention = ["dense_attention"]
+# Triplet used on every 16th layer by `enc_dec_with_lyrics` ("prime" presumably refers to the
+# lyric encoder — TODO confirm against modeling_jukebox).
+_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]
+
+
+def full_dense_attention(layer):
+    """Return the attention type for `layer`: always dense attention."""
+    return _FullDenseAttention[0]
+
+
+def raw_column_previous_row_attention(layer):
+    """Cycle block / transpose-block / previous-block attention with period 3 over layers."""
+    return _RawColumnPreviousRowAttention[layer % 3]
+
+
+def large_separated_enc_dec_w_lyrics(layer):
+    """Look up the attention type for `layer` in the fixed 79-entry `_LARGE_ATTENTION` schedule."""
+    return _LARGE_ATTENTION[layer % 79]
+
+
+def enc_dec_with_lyrics(layer):
+    """Every 16th layer (layer % 16 == 15) draws from the prime/prime/dense triplet;
+    all other layers cycle the row/column/previous-row triplet."""
+    if layer % 16 == 15:
+        return _PrimePrimeDenseAttention[layer % 3]
+    return _RawColumnPreviousRowAttention[layer % 3]
+
+
+# Maps `config.attention_pattern` names to selector functions (layer index -> attention type).
+ATTENTION_PATTERNS = {
+    "full_dense_attention": full_dense_attention,
+    "raw_column_previous_row_attention": raw_column_previous_row_attention,  # Alternate row, column and previous row attn
+    "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics,  # Used by large separated_enc_dec model with lyrics
+    "enc_dec_with_lyrics": enc_dec_with_lyrics,  # Used by encoder_decoder model with lyrics
+}
+
+
+class JukeboxPriorConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
+    `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the top level prior from the
+    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        act_fn (`str`, *optional*, defaults to `"quick_gelu"`):
+            Activation function.
+        alignment_head (`int`, *optional*, defaults to 2):
+            Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio
+            alignment
+        alignment_layer (`int`, *optional*, defaults to 68):
+            Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the
+            lyric to audio alignment
+        attention_multiplier (`float`, *optional*, defaults to 0.25):
+            Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that
+            0.25*width of the model will be used.
+        attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`):
+            Which attention pattern to use for the decoder.
+        attn_dropout (`int`, *optional*, defaults to 0):
+            Dropout probability for the post-attention layer dropout in the decoder.
+        attn_res_scale (`bool`, *optional*, defaults to `False`):
+            Whether or not to scale the residuals in the attention conditioner block.
+        blocks (`int`, *optional*, defaults to 64):
+            Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len //
+            blocks]` in the `JukeboxAttention` layer.
+        conv_res_scale (`int`, *optional*):
+            Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a
+            conditioner, the default value is to None and should not be modified.
+        num_layers (`int`, *optional*, defaults to 72):
+            Number of layers of the transformer architecture.
+        emb_dropout (`int`, *optional*, defaults to 0):
+            Embedding dropout used in the lyric decoder.
+        encoder_config (`JukeboxPriorConfig`, *optional*):
+            Configuration of the encoder which models the prior on the lyrics.
+        encoder_loss_fraction (`float`, *optional*, defaults to 0.4):
+            Multiplication factor used in front of the lyric encoder loss.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Hidden dimension of the attention layers.
+        init_scale (`float`, *optional*, defaults to 0.2):
+            Initialization scales for the prior modules.
+        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+            Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is
+            greater than 0, the `encoder` args should be specified for the lyric encoding.
+        level (`int`, *optional*, defaults to 0):
+            Which prior of the full Jukebox model this configuration corresponds to (the sub-config is stored as
+            `prior_{level}` in a full `JukeboxConfig`).
+        lyric_vocab_size (`int`, *optional*, defaults to 80):
+            Size of the vocabulary used to encode the lyric tokens.
+        mask (`bool`, *optional*, defaults to `False`):
+            Whether or not to mask the previous positions in the attention.
+        max_duration (`int`, *optional*, defaults to 600):
+            Maximum supported duration of the generated song in seconds.
+        max_nb_genres (`int`, *optional*, defaults to 1):
+            Maximum number of genres that can be used to condition the model.
+        merged_decoder (`bool`, *optional*, defaults to `True`):
+            Whether or not the decoder and the encoder inputs are merged. This is used for the separated
+            encoder-decoder architecture
+        metadata_conditioning (`bool`, *optional*, defaults to `True`):
+            Whether or not to condition on the artist and genre metadata.
+        metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`):
+            Number of genres and the number of artists that were used to train the embedding layers of the prior
+            models.
+        min_duration (`int`, *optional*, defaults to 0):
+            Minimum duration of the generated audio on which the model was trained.
+        mlp_multiplier (`float`, *optional*, defaults to 1.0):
+            Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of
+            the model will be used.
+        music_vocab_size (`int`, *optional*, defaults to 2048):
+            Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`.
+        n_ctx (`int`, *optional*, defaults to 6144):
+            Number of context tokens for each prior. The context tokens are the music tokens that are attended to when
+            generating music tokens.
+        n_heads (`int`, *optional*, defaults to 2):
+            Number of attention heads.
+        nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384):
+            Number of lyric tokens that are used when sampling a single window of length `n_ctx`
+        res_conv_depth (`int`, *optional*, defaults to 3):
+            Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
+            `JukeboxMusicTokenConditioner`.
+        res_conv_width (`int`, *optional*, defaults to 128):
+            Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
+            `JukeboxMusicTokenConditioner`.
+        res_convolution_multiplier (`int`, *optional*, defaults to 1):
+            Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`.
+        res_dilation_cycle (`int`, *optional*):
+            Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the
+            corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level
+            tokens.
+        res_dilation_growth_rate (`int`, *optional*, defaults to 1):
+            Dilation grow rate used between each convolutional block of the `JukeboxMusicTokenConditioner`
+        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
+            Downsampling rates used in the audio conditioning network
+        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
+            Striding used in the audio conditioning network
+        resid_dropout (`int`, *optional*, defaults to 0):
+            Residual dropout used in the attention pattern.
+        sampling_rate (`int`, *optional*, defaults to 44100):
+            Sampling rate used for training.
+        spread (`int`, *optional*):
+            Spread used in the `summary_spread_attention` pattern
+        timing_dims (`int`, *optional*, defaults to 64):
+            Dimension of the timing embedding.
+        zero_out (`bool`, *optional*, defaults to `False`):
+            Whether or not to zero out convolution weights when initializing.
+    """
+
+    model_type = "jukebox_prior"
+    # NOTE(review): the mapped targets are "n_positions" and "n_head", but __init__ below sets
+    # `n_ctx` and `n_heads` — confirm these aliases resolve as intended via PretrainedConfig.
+    attribute_map = {
+        "max_position_embeddings": "n_positions",
+        "num_attention_heads": "n_head",
+    }
+
+    def __init__(
+        self,
+        act_fn="quick_gelu",
+        level=0,
+        alignment_head=2,
+        alignment_layer=68,
+        attention_multiplier=0.25,
+        attention_pattern="enc_dec_with_lyrics",
+        attn_dropout=0,
+        attn_res_scale=False,
+        blocks=64,
+        conv_res_scale=None,
+        num_layers=72,
+        emb_dropout=0,
+        encoder_config=None,
+        encoder_loss_fraction=0.4,
+        hidden_size=2048,
+        init_scale=0.2,
+        is_encoder_decoder=True,
+        lyric_vocab_size=80,
+        mask=False,
+        max_duration=600,
+        max_nb_genres=1,
+        merged_decoder=True,
+        metadata_conditioning=True,
+        metadata_dims=[604, 7898],
+        min_duration=0,
+        mlp_multiplier=1.0,
+        music_vocab_size=2048,
+        n_ctx=6144,
+        n_heads=2,
+        nb_relevant_lyric_tokens=384,
+        res_conv_depth=3,
+        res_conv_width=128,
+        res_convolution_multiplier=1,
+        res_dilation_cycle=None,
+        res_dilation_growth_rate=1,
+        res_downs_t=[3, 2, 2],
+        res_strides_t=[2, 2, 2],
+        resid_dropout=0,
+        sampling_rate=44100,
+        spread=None,
+        timing_dims=64,
+        zero_out=False,
+        **kwargs,
+    ):
+        self.act_fn = act_fn
+        self.alignment_head = alignment_head
+        self.alignment_layer = alignment_layer
+        self.attention_multiplier = attention_multiplier
+        self.attention_pattern = attention_pattern
+        self.attn_dropout = attn_dropout
+        self.attn_res_scale = attn_res_scale
+        self.blocks = blocks
+        self.conv_res_scale = conv_res_scale
+        self.num_layers = num_layers
+        self.emb_dropout = emb_dropout
+        self.music_vocab_size = music_vocab_size
+        # The lyric-encoder sub-config is itself a JukeboxPriorConfig, built from a plain dict.
+        if encoder_config is not None:
+            self.encoder_config = JukeboxPriorConfig(**encoder_config)
+        else:
+            self.encoder_config = None
+        self.encoder_loss_fraction = encoder_loss_fraction
+        self.init_scale = init_scale
+        self.is_encoder_decoder = is_encoder_decoder
+        self.lyric_vocab_size = lyric_vocab_size
+        self.level = level
+        self.mask = mask
+        self.max_duration = max_duration
+        self.max_nb_genres = max_nb_genres
+        self.merged_decoder = merged_decoder
+        self.metadata_conditioning = metadata_conditioning
+        self.metadata_dims = metadata_dims
+        self.min_duration = min_duration
+        self.mlp_multiplier = mlp_multiplier
+        self.n_ctx = n_ctx
+        self.n_heads = n_heads
+        self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
+        self.res_conv_depth = res_conv_depth
+        self.res_conv_width = res_conv_width
+        self.res_convolution_multiplier = res_convolution_multiplier
+        self.res_dilation_cycle = res_dilation_cycle
+        self.res_dilation_growth_rate = res_dilation_growth_rate
+        self.res_downs_t = res_downs_t
+        self.res_strides_t = res_strides_t
+        self.resid_dropout = resid_dropout
+        self.sampling_rate = sampling_rate
+        self.spread = spread
+        self.timing_dims = timing_dims
+        self.hidden_size = hidden_size
+        self.zero_out = zero_out
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
+    ) -> "PretrainedConfig":
+        """Load a `JukeboxPriorConfig`; when the checkpoint holds a full `jukebox` config,
+        extract the sub-config stored under `prior_{level}`."""
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        # get the prior config dict if we are loading from JukeboxConfig
+        if config_dict.get("model_type") == "jukebox":
+            config_dict = config_dict[f"prior_{level}"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
+
class JukeboxVQVAEConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
    `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VQVAE from
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function of the model.
        nb_discrete_codes (`int`, *optional*, defaults to 2048):
            Number of codes of the VQVAE.
        commit (`float`, *optional*, defaults to 0.02):
            Commit loss multiplier.
        conv_input_shape (`int`, *optional*, defaults to 1):
            Number of audio channels.
        conv_res_scale (`bool`, *optional*, defaults to `False`):
            Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
        embed_dim (`int`, *optional*, defaults to 64):
            Embedding dimension of the codebook vectors.
        hop_fraction (`List[float]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
            Fraction of non-intersecting window used when continuing the sampling process.
        levels (`int`, *optional*, defaults to 3):
            Number of hierarchical levels that used in the VQVAE.
        lmu (`float`, *optional*, defaults to 0.99):
            Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1
            of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf)
        multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
            Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth`
        res_conv_depth (`int`, *optional*, defaults to 4):
            Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_conv_width (`int`, *optional*, defaults to 32):
            Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
        res_convolution_multiplier (`int`, *optional*, defaults to 1):
            Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
        res_dilation_cycle (`int`, *optional*):
            Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth
            reduced by a power of `res_dilation_cycle`.
        res_dilation_growth_rate (`int`, *optional*, defaults to 3):
            Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth)
        res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
            Downsampling rate for each level of the hierarchical VQ-VAE.
        res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
            Stride used for each level of the hierarchical VQ-VAE.
        sample_length (`int`, *optional*, defaults to 1058304):
            Provides the max input shape of the VQVAE. Is used to compute the input shape of each level.
        init_scale (`float`, *optional*, defaults to 0.2):
            Initialization scale.
        zero_out (`bool`, *optional*, defaults to `False`):
            Whether or not to zero out convolution weights when initializing.
    """

    model_type = "jukebox_vqvae"

    def __init__(
        self,
        act_fn="relu",
        nb_discrete_codes=2048,
        commit=0.02,
        conv_input_shape=1,
        conv_res_scale=False,
        embed_dim=64,
        hop_fraction=None,
        levels=3,
        lmu=0.99,
        multipliers=None,
        res_conv_depth=4,
        res_conv_width=32,
        res_convolution_multiplier=1,
        res_dilation_cycle=None,
        res_dilation_growth_rate=3,
        res_downs_t=None,
        res_strides_t=None,
        sample_length=1058304,
        init_scale=0.2,
        zero_out=False,
        **kwargs,
    ):
        # List arguments default to None and are materialized here so that each
        # instance gets its own copy (avoids the shared mutable-default pitfall).
        self.hop_fraction = [0.125, 0.5, 0.5] if hop_fraction is None else hop_fraction
        self.conv_input_shape = conv_input_shape
        self.sample_length = sample_length

        # VQVAE parameters (all used)
        self.levels = levels
        self.embed_dim = embed_dim
        self.nb_discrete_codes = nb_discrete_codes
        self.res_conv_width = res_conv_width
        self.res_conv_depth = res_conv_depth
        self.res_convolution_multiplier = res_convolution_multiplier
        self.res_dilation_growth_rate = res_dilation_growth_rate
        self.res_dilation_cycle = res_dilation_cycle
        self.multipliers = [2, 1, 1] if multipliers is None else multipliers
        self.res_downs_t = [3, 2, 2] if res_downs_t is None else res_downs_t
        self.res_strides_t = [2, 2, 2] if res_strides_t is None else res_strides_t
        self.lmu = lmu
        self.commit = commit
        self.conv_res_scale = conv_res_scale
        self.act_fn = act_fn
        self.init_scale = init_scale
        self.zero_out = zero_out
        # NOTE(review): unlike `JukeboxConfig`, this __init__ never calls
        # `super().__init__(**kwargs)`, so extra kwargs are silently dropped and
        # `PretrainedConfig` defaults are not populated. Kept as-is for parity with
        # the upstream implementation — confirm before changing.

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a `JukeboxVQVAEConfig`; when pointed at a composite `JukeboxConfig`
        checkpoint, the nested `vqvae_config` sub-dict is used."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vqvae config dict if we are loading from JukeboxConfig
        if config_dict.get("model_type") == "jukebox":
            config_dict = config_dict["vqvae_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
+
+
class JukeboxConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`JukeboxModel`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
    yield a similar configuration to that of
    [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.


    The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
    (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256
    to get the second level codes. This is mostly true for training the top level prior and the upsamplers.

    Args:
        vqvae_config (`JukeboxVQVAEConfig`, *optional*):
            Configuration for the `JukeboxVQVAE` model.
        prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
            List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
        nb_priors (`int`, *optional*, defaults to 3):
            Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
            (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
            trained using a top prior and 2 upsampler priors.
        sampling_rate (`int`, *optional*, defaults to 44100):
            Sampling rate of the raw audio.
        timing_dims (`int`, *optional*, defaults to 64):
            Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
            layer. The timing embedding layer converts the absolute and relative position in the currently sampled
            audio to a tensor of length `timing_dims` that will be added to the music tokens.
        min_duration (`int`, *optional*, defaults to 0):
            Minimum duration of the audios to generate
        max_duration (`float`, *optional*, defaults to 600.0):
            Maximum duration of the audios to generate
        max_nb_genres (`int`, *optional*, defaults to 5):
            Maximum number of genres that can be used to condition a single sample.
        metadata_conditioning (`bool`, *optional*, defaults to `True`):
            Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
            duration.

    Example:

    ```python
    >>> from transformers import JukeboxModel, JukeboxConfig

    >>> # Initializing a Jukebox configuration
    >>> configuration = JukeboxConfig()

    >>> # Initializing a model from the configuration
    >>> model = JukeboxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "jukebox"

    def __init__(
        self,
        vqvae_config=None,
        prior_config_list=None,
        nb_priors=3,
        sampling_rate=44100,
        timing_dims=64,
        min_duration=0,
        max_duration=600.0,
        max_nb_genres=5,
        metadata_conditioning=True,
        **kwargs,
    ):
        if vqvae_config is None:
            vqvae_config = {}
            logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")

        self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)
        if prior_config_list is not None:
            self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
        else:
            # Fall back to per-prior kwargs (`prior_0`, `prior_1`, ...) as stored in
            # serialized composite checkpoints; missing ones get default values.
            self.prior_configs = []
            for prior_idx in range(nb_priors):
                prior_config = kwargs.pop(f"prior_{prior_idx}", None)
                if prior_config is None:
                    prior_config = {}
                    logger.info(
                        f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default"
                        " values."
                    )
                self.prior_configs.append(JukeboxPriorConfig(**prior_config))

        # Mirrored from the VQ-VAE config for convenient access during sampling.
        self.hop_fraction = self.vqvae_config.hop_fraction

        self.nb_priors = nb_priors

        # Metadata conditioning
        self.max_nb_genres = max_nb_genres
        self.sampling_rate = sampling_rate
        self.timing_dims = timing_dims
        self.min_duration = min_duration
        self.max_duration = max_duration
        self.metadata_conditioning = metadata_conditioning

        super().__init__(**kwargs)

    @classmethod
    def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
        r"""
        Instantiate a [`JukeboxConfig`] (or a derived class) from a list of prior configurations and a VQ-VAE model
        configuration.

        Returns:
            [`JukeboxConfig`]: An instance of a configuration object
        """
        prior_config_list = [config.to_dict() for config in prior_configs]
        # Fixed: previously the dict was passed as `vqvae_config_dict=`, a kwarg
        # `__init__` does not accept, so the VQ-VAE settings were silently ignored
        # and the defaults were used instead.
        return cls(prior_config_list=prior_config_list, vqvae_config=vqvae_config.to_dict(), **kwargs)

    def to_dict(self):
        # Override the default to_dict to apply to_dict to the list of prior configs.
        result = super().to_dict()
        result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
        return result
+
+
+__all__ = ["JukeboxConfig", "JukeboxPriorConfig", "JukeboxVQVAEConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py
new file mode 100644
index 0000000000000000000000000000000000000000..aac3b2efe733bd9f0c4eefb2d5442e15427c9347
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py
@@ -0,0 +1,279 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Jukebox checkpoints"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+
+import requests
+import torch
+
+from transformers import JukeboxConfig, JukeboxModel
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
# Base URL of the public bucket hosting the original OpenAI Jukebox checkpoints.
PREFIX = "https://openaipublic.azureedge.net/jukebox/models/"
# Checkpoint files needed per model, ordered: VQ-VAE first, then the priors from
# bottom level (0) to top level (2). Both variants share the `5b` VQ-VAE and the
# two upsampler priors; only the top-level (lyric-conditioned) prior differs.
MODEL_MAPPING = {
    "jukebox-1b-lyrics": [
        "5b/vqvae.pth.tar",
        "5b/prior_level_0.pth.tar",
        "5b/prior_level_1.pth.tar",
        "1b_lyrics/prior_level_2.pth.tar",
    ],
    "jukebox-5b-lyrics": [
        "5b/vqvae.pth.tar",
        "5b/prior_level_0.pth.tar",
        "5b/prior_level_1.pth.tar",
        "5b_lyrics/prior_level_2.pth.tar",
    ],
}
+
+
def replace_key(key):
    """
    Translate a single original OpenAI Jukebox state-dict key into the naming
    scheme used by the Transformers implementation.

    The checks are order-sensitive: branches that `return` stop the cascade,
    while the others rewrite the key in place and keep going.
    """
    # Deeply nested attention-block convolutions: `.model.1` / `.model.3`
    # become `conv1d_1` / `conv1d_2`. The depth check (> 10 dot-separated
    # parts) keeps shallow keys untouched.
    if len(key.split(".")) > 10:
        for old_suffix, new_suffix in (
            (".model.1.bias", ".conv1d_1.bias"),
            (".model.1.weight", ".conv1d_1.weight"),
            (".model.3.bias", ".conv1d_2.bias"),
            (".model.3.weight", ".conv1d_2.weight"),
        ):
            if key.endswith(old_suffix):
                key = key.replace(old_suffix, new_suffix)
                break

    # The single conditioner block loses its index.
    if "conditioner_blocks.0." in key:
        key = key.replace("conditioner_blocks.0", "conditioner_blocks")

    # The lyric "prime" prior is called the encoder in the HF model.
    if "prime_prior" in key:
        key = key.replace("prime_prior", "encoder")

    # Plain embedding wrappers are flattened, but timing embeddings keep `.emb.`.
    if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key:
        key = key.replace(".emb.", ".")

    if key.endswith("k"):  # replace vqvae.X.k with vqvae.X.codebook
        return key.replace(".k", ".codebook")
    if "y_emb." in key:
        return key.replace("y_emb.", "metadata_embedding.")

    if "x_emb.emb." in key:
        key = key.replace("0.x_emb.emb", "embed_tokens")

    # Layer-norm renames (checked after the encoder-specific final norm).
    if "prime_state_ln" in key:
        return key.replace("prime_state_ln", "encoder.final_layer_norm")
    if ".ln" in key:
        return key.replace(".ln", ".layer_norm")
    if "_ln" in key:
        return key.replace("_ln", "_layer_norm")

    if "prime_state_proj" in key:
        return key.replace("prime_state_proj", "encoder.proj_in")
    if "prime_x_out" in key:
        return key.replace("prime_x_out", "encoder.lm_head")
    if "prior.x_out" in key:
        return key.replace("x_out", "fc_proj_out")
    if "x_emb" in key:
        return key.replace("x_emb", "embed_tokens")

    return key
+
+
def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping):
    """
    Rename every key of an original Jukebox `state_dict` to the Transformers
    naming scheme, validating each renamed key against `model_state_dict`
    (the target model's state dict) under `key_prefix` (e.g. "vqvae" or
    "priors.N"). `mapping` is filled in-place with {new_key: original_key}
    and the renamed dict is returned.
    """
    new_dict = {}
    import re

    # NOTE(review): the `.` in these patterns is an unescaped regex wildcard;
    # it happens to work because the keys only contain literal dots.
    re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    re_encoder_block_resnet = re.compile(
        r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)")
    re_decoder_block_resnet = re.compile(
        r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)")

    re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)")
    re_prior_cond_resnet = re.compile(
        r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)"
    )
    re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)")

    # The branches are ordered most-specific-first; `fullmatch` decides the
    # branch and `match` re-runs the same pattern to extract the groups.
    for original_key, value in state_dict.items():
        # rename vqvae.encoder keys
        if re_encoder_block_conv_in.fullmatch(original_key):
            regex_match = re_encoder_block_conv_in.match(original_key)
            groups = regex_match.groups()
            # Two modules (conv + resnet) per downsampling step: flatten the
            # nested (step, sub-index) numbering into a single block index.
            block_index = int(groups[2]) * 2 + int(groups[3])
            re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}.{groups[-1]}"
            key = re_encoder_block_conv_in.sub(re_new_key, original_key)

        elif re_encoder_block_resnet.fullmatch(original_key):
            regex_match = re_encoder_block_resnet.match(original_key)
            groups = regex_match.groups()
            block_index = int(groups[2]) * 2 + int(groups[3])
            # Sub-module 1 -> conv1d_1, sub-module 3 -> conv1d_2 (see JukeboxResConv1DBlock).
            conv_index = {"1": 1, "3": 2}[groups[-2]]
            prefix = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}."
            resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
            re_new_key = prefix + resnet_block
            key = re_encoder_block_resnet.sub(re_new_key, original_key)

        elif re_encoder_block_proj_out.fullmatch(original_key):
            regex_match = re_encoder_block_proj_out.match(original_key)
            groups = regex_match.groups()
            re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}"
            key = re_encoder_block_proj_out.sub(re_new_key, original_key)

        # rename vqvae.decoder keys
        elif re_decoder_block_conv_out.fullmatch(original_key):
            regex_match = re_decoder_block_conv_out.match(original_key)
            groups = regex_match.groups()
            # The decoder's first module is proj_in, hence the -2 shift.
            block_index = int(groups[2]) * 2 + int(groups[3]) - 2
            re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}.{groups[-1]}"
            key = re_decoder_block_conv_out.sub(re_new_key, original_key)

        elif re_decoder_block_resnet.fullmatch(original_key):
            regex_match = re_decoder_block_resnet.match(original_key)
            groups = regex_match.groups()
            block_index = int(groups[2]) * 2 + int(groups[3]) - 2
            conv_index = {"1": 1, "3": 2}[groups[-2]]
            prefix = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}."
            resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
            re_new_key = prefix + resnet_block
            key = re_decoder_block_resnet.sub(re_new_key, original_key)

        elif re_decoder_block_proj_in.fullmatch(original_key):
            regex_match = re_decoder_block_proj_in.match(original_key)
            groups = regex_match.groups()
            re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}"
            key = re_decoder_block_proj_in.sub(re_new_key, original_key)

        # rename prior cond.model to upsampler.upsample_block and resnet
        elif re_prior_cond_conv_out.fullmatch(original_key):
            regex_match = re_prior_cond_conv_out.match(original_key)
            groups = regex_match.groups()
            block_index = int(groups[1]) * 2 + int(groups[2]) - 2
            re_new_key = f"conditioner_blocks.upsampler.upsample_block.{block_index}.{groups[-1]}"
            key = re_prior_cond_conv_out.sub(re_new_key, original_key)

        elif re_prior_cond_resnet.fullmatch(original_key):
            regex_match = re_prior_cond_resnet.match(original_key)
            groups = regex_match.groups()
            block_index = int(groups[1]) * 2 + int(groups[2]) - 2
            conv_index = {"1": 1, "3": 2}[groups[-2]]
            prefix = f"conditioner_blocks.upsampler.upsample_block.{block_index}."
            resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}"
            re_new_key = prefix + resnet_block
            key = re_prior_cond_resnet.sub(re_new_key, original_key)

        elif re_prior_cond_proj_in.fullmatch(original_key):
            regex_match = re_prior_cond_proj_in.match(original_key)
            groups = regex_match.groups()
            re_new_key = f"conditioner_blocks.upsampler.proj_in.{groups[-1]}"
            key = re_prior_cond_proj_in.sub(re_new_key, original_key)

        # keep original key
        else:
            key = original_key

        # Apply the non-regex renames on top of the structural ones.
        key = replace_key(key)

        # NOTE(review): `key is None` can never be true here — `replace_key`
        # always returns a string; the check is dead code kept for parity.
        if f"{key_prefix}.{key}" not in model_state_dict or key is None:
            print(f"failed converting {original_key} to {key}, does not match")

        # handle missmatched shape
        elif value.shape != model_state_dict[f"{key_prefix}.{key}"].shape:
            val = model_state_dict[f"{key_prefix}.{key}"]
            print(f"{original_key}-> {key} : \nshape {val.shape} and {value.shape}, do not match")
            key = original_key

        mapping[key] = original_key
        new_dict[key] = value

    return new_dict
+
+
@torch.no_grad()
def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None):
    """
    Copy/paste/tweak model's weights to our Jukebox structure.

    Downloads any missing original OpenAI checkpoints for `model_name` into
    `pytorch_dump_folder_path`, renames every weight to the HF naming scheme,
    loads them into a fresh `JukeboxModel`, and saves the converted model plus
    a `mapping.json` of {new_key: original_key}.
    """
    # Download the original checkpoint files that are not already on disk.
    for file in MODEL_MAPPING[model_name]:
        if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"):
            r = requests.get(f"{PREFIX}{file}", allow_redirects=True)
            # Fail loudly on HTTP errors instead of writing an error page to disk.
            r.raise_for_status()
            os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True)
            # Context manager guarantees the file handle is closed (was leaked before).
            with open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb") as ckpt_file:
                ckpt_file.write(r.content)

    model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]]

    config = JukeboxConfig.from_pretrained(model_name)
    model = JukeboxModel(config)

    weight_dict = []
    mapping = {}
    for i, dict_name in enumerate(model_to_convert):
        # `weights_only=True` keeps torch.load from unpickling arbitrary objects.
        old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}", weights_only=True)["model"]

        # First pass: expand the original `.b`/`.w` suffixes.
        # NOTE(review): `k.replace("b", "bias")` rewrites every "b"/"w" in the key,
        # not just the suffix — the original checkpoint keys happen to be safe,
        # but confirm before reusing this on other checkpoints.
        new_dic = {}
        for k in old_dic.keys():
            if k.endswith(".b"):
                new_dic[k.replace("b", "bias")] = old_dic[k]
            elif k.endswith(".w"):
                new_dic[k.replace("w", "weight")] = old_dic[k]
            elif "level_2" not in dict_name and "cond.model." in k:
                new_dic[k.replace(".blocks.", ".model.")] = old_dic[k]
            else:
                new_dic[k] = old_dic[k]

        # Checkpoint 0 is the VQ-VAE; the rest are priors stored top-level-first.
        key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}"
        new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping)
        weight_dict.append(new_dic)

    vqvae_state_dict = weight_dict.pop(0)
    model.vqvae.load_state_dict(vqvae_state_dict)
    for i in range(len(weight_dict)):
        # Priors were converted top-first, but `model.priors` is ordered bottom-up.
        model.priors[i].load_state_dict(weight_dict[2 - i])

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile:
        json.dump(mapping, txtfile)

    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    # NOTE(review): the VQ-VAE dict was popped above, so only the prior dicts are returned.
    return weight_dict
+
+
# CLI entry point: convert one of the original OpenAI Jukebox checkpoints to the
# Transformers format, e.g. `python convert_jukebox.py --model_name jukebox-1b-lyrics`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="jukebox-5b-lyrics",
        type=str,
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="jukebox-5b-lyrics-converted",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    args = parser.parse_args()
    convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py
new file mode 100644
index 0000000000000000000000000000000000000000..566148ceda36b3581c653c2c29ec7265d2582135
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py
@@ -0,0 +1,2670 @@
+# coding=utf-8
+# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Jukebox model."""
+
+import math
+import os
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import LayerNorm as FusedLayerNorm
+
+from ....activations import ACT2FN
+from ....modeling_utils import PreTrainedModel
+from ....utils import add_start_docstrings, logging
+from ....utils.logging import tqdm
+from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig
+
+
+logger = logging.get_logger(__name__)
+
+
def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits (`torch.Tensor`):
            Logits distribution; filtering is applied over the last dimension.
        top_k (`int`, *optional*, defaults to 0):
            When `top_k > 0`, keep only the `top_k` tokens with highest probability (top-k filtering).
        top_p (`float`, *optional*, defaults to 0.0):
            When `top_p > 0.0`, keep the smallest set of tokens whose cumulative probability
            reaches `top_p` (nucleus filtering).

    Returns:
        `torch.Tensor`: a new tensor (the input is never mutated) with removed
        entries set to `filter_value`.
    """
    logits = logits.clone()  # work on a copy; the caller's tensor stays intact
    top_k = min(top_k, logits.size(-1))  # cannot keep more tokens than exist

    if top_k > 0:
        # Everything strictly below the k-th largest logit is dropped.
        kth_largest = torch.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        descending, order = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.cumsum(F.softmax(descending, dim=-1), dim=-1)

        # Mark tokens past the nucleus threshold for removal...
        drop_sorted = cumulative > top_p
        # ...then shift right so the first token crossing the threshold is kept.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Scatter the sorted-order mask back to the original token positions.
        drop = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            dim=-1, index=order, src=drop_sorted
        )
        logits[drop] = filter_value
    return logits
+
+
def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration):
    """
    Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be
    returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the
    midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on
    the most relevant tokens (in time) for the sequence.

    Args:
        full_tokens (`torch.Tensor` of shape `(1, lyric_length)`):
            Token ids of the entire lyrics; only the first row is used.
        max_n_lyric_tokens (`int`):
            Size of the returned token window.
        total_length (`int`):
            Total expected length of the music (not all of it is generated, see duration), in samples.
        offset (`int`):
            Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted take that into
            account
        duration (`int`):
            Expected duration of the generated music, in samples. The duration has to be smaller than the total
            length, which represent the overall length of the signal.

    Returns:
        `(torch.Tensor, List[int])`: the `(1, max_n_lyric_tokens)` token window and, for each slot, the index it came
        from in the original lyrics (`-1` for padding).
    """
    full_tokens = full_tokens[0]
    half_window = max_n_lyric_tokens // 2

    if len(full_tokens) < max_n_lyric_tokens:
        # Too few tokens: left-pad with zeros; padded slots map to index -1.
        pad_count = max_n_lyric_tokens - len(full_tokens)
        padding = torch.zeros(pad_count, dtype=torch.long).to(full_tokens.device)
        tokens = torch.cat([padding, full_tokens])
        indices = [-1] * pad_count + list(range(len(full_tokens)))
    else:
        # Center the window on the lyric position matching the audio midpoint,
        # clamped so the window stays inside the sequence.
        midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length)
        midpoint = min(max(midpoint, half_window), len(full_tokens) - half_window)
        tokens = full_tokens[midpoint - half_window : midpoint + half_window]
        indices = list(range(midpoint - half_window, midpoint + half_window))

    return tokens.unsqueeze(dim=0), indices
+
+
# Break total_length into hops/windows of size n_ctx separated by hop_length
def get_starts(total_length, n_ctx, hop_length):
    """Return the start offset of every sampling window of size `n_ctx`.

    Starts advance by `hop_length`; the final start is clamped to
    `total_length - n_ctx` so the last window still spans a full context.
    """
    return [
        min(start, total_length - n_ctx)
        for start in range(0, total_length - n_ctx + hop_length, hop_length)
    ]
+
+
def get_alignment(music_tokens, labels, prior, config):
    """
    Compute the lyric-to-music attention alignment for each item in the batch.

    Runs the top-level `prior` over sliding windows of the top-level music
    tokens, collects the attention weights of one configured head/layer, and
    stitches the per-window alignments back into one
    `(total_length, n_lyric_tokens)` numpy array per batch item.
    """
    level = prior.levels - 1  # Top level used
    n_ctx = prior.n_ctx
    tokens = music_tokens[level]
    batch_size, total_length = tokens.shape[0], tokens.shape[1]
    # Right-pad the token sequence so at least one full context window fits;
    # the padding is stripped from the alignment again at the end.
    if total_length < n_ctx:
        padding_length = n_ctx - total_length
        tokens = torch.cat(
            [tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1
        )
        total_length = tokens.shape[1]
    else:
        padding_length = 0

    hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx)
    alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0]
    attn_layers = {alignment_layer}
    alignment_hops = {}
    indices_hops = {}
    # One forward pass per window; items are processed one at a time to bound memory.
    for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "):
        end = start + n_ctx
        # set metadata offset, sample_length and lyrics tokens
        metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0)
        tokens_bs = torch.chunk(tokens, batch_size, dim=0)
        metadata_bs = torch.chunk(metadata, batch_size, dim=0)
        w_hops = []
        for tokens_i, metadata_i in zip(tokens_bs, metadata_bs):
            w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers)
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        weights = torch.cat(w_hops, dim=0)
        del w_hops
        alignment_hop = weights.to(device="cpu", dtype=torch.float).numpy()
        del weights

        # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens)
        # indices_hop is a list of len=bs, each entry of len hps.nb_relevant_lyric_tokens
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(batch_size):
        # Note each item has different length lyrics
        # NOTE(review): `labels[0, 3:]` always reads batch item 0, ignoring `item`
        # — looks correct only for batch_size == 1 or identical lyrics per item;
        # confirm before relying on larger batches.
        full_tokens = labels[0, 3:]
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        # Later (earlier-start) windows overwrite overlapping regions.
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            alignment[start:end, indices] = alignment_hop
        alignment = alignment[: total_length - padding_length, :-1]  # remove token padding, and last lyric index
        alignments.append(alignment)
    return alignments
+
+
def save_temp_audio(fname, lvl, metas, aud):
    """Clamp a batch of audio to [-1, 1] and dump each item as a `.npy` file
    under `fname`, naming files from the metadata (artist/genre/lyrics) when
    `metas` is provided, otherwise with a plain sample index."""
    clipped = torch.clamp(aud, -1, 1).cpu().numpy()
    for sample_index in range(clipped.shape[0]):
        if metas is not None:
            artists, genres, lyrics = list(metas)[sample_index].values()
            path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{sample_index}"
            np.save(path, clipped[sample_index])
        else:
            np.save(f"{fname}/lvl_{lvl}-sample-{sample_index}", clipped[sample_index])
+
+
def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t):
    """
    Build an attention mask of shape `(1, 1, query_length, key_value_length)`,
    or return `None` when no masking is needed (no mask requested, or a single
    query position during sampling).

    `mask` selects the pattern: `"autoregressive"` (dense causal), `"summary"`
    (block-summary attention) or `"prime"` (causal over the lyric prime).
    NOTE(review): `spread` is accepted but unused here — kept for signature
    parity with the attention patterns; confirm before removing.
    """
    if mask is None or query_length == 1:
        return None
    # While sampling the offset is measured from the current sampling step,
    # otherwise from the key/value overhang over the queries.
    offset = sample_t - query_length if sample else max(key_value_length - query_length, 0)
    if mask == "autoregressive":
        # Masked dense
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    elif mask == "summary":
        # Masked summary (a duplicated `tril` statement from upstream was removed here).
        mask = torch.ones(query_length, query_length, device=device).tril()
        mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :]
        mask = (
            torch.nn.functional.pad(
                mask,
                (0, 0, 1, 0),
                value=1,
            )
            .contiguous()
            .view(query_length, key_value_length)
        )
    elif mask == "prime":
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    return mask.view(1, 1, query_length, key_value_length)
+
+
class JukeboxConv1D(nn.Module):
    """GPT-style 1D convolution: a linear projection applied over the last
    dimension of the input (weight stored as `(input_width, output_width)`)."""

    def __init__(self, input_width, output_width):
        super().__init__()
        self.input_width = input_width
        self.output_width = output_width
        # Weight is left uninitialized on purpose; model init / checkpoint
        # loading is expected to fill it in. Bias starts at zero.
        self.weight = nn.Parameter(torch.empty(input_width, output_width))
        self.bias = nn.Parameter(torch.zeros(output_width))

    def forward(self, hidden_states):
        # Flatten all leading dimensions, project, then restore them with the
        # new channel width.
        output_shape = (*hidden_states.size()[:-1], self.output_width)
        flat = hidden_states.view(-1, hidden_states.size(-1))
        projected = torch.addmm(
            self.bias.type_as(hidden_states),
            flat,
            self.weight.type_as(hidden_states),
        )
        return projected.view(*output_shape)
+
+
class JukeboxResConv1DBlock(nn.Module):
    """Dilated residual block: ReLU -> dilated 3-tap conv -> ReLU -> 1x1 conv,
    added back onto the input scaled by `res_scale`."""

    def __init__(self, config, conv_width, depth=1, res_scale=1.0):
        super().__init__()
        bottleneck_dim = config.res_convolution_multiplier * conv_width
        dilation = config.res_dilation_growth_rate**depth

        self.res_scale = res_scale
        self.activation = nn.ReLU()
        # padding == dilation keeps the sequence length unchanged for kernel size 3.
        self.conv1d_1 = nn.Conv1d(conv_width, bottleneck_dim, 3, 1, dilation, dilation)
        self.conv1d_2 = nn.Conv1d(bottleneck_dim, conv_width, 1, 1, 0)

    def forward(self, hidden_states):
        skip = hidden_states
        out = self.conv1d_1(self.activation(hidden_states))
        out = self.conv1d_2(self.activation(out))
        return skip + self.res_scale * out
+
+
class JukeboxResnet1D(nn.Module):
    """Stack of `JukeboxResConv1DBlock`s whose dilation grows with depth,
    optionally cycling every `res_dilation_cycle` blocks and optionally
    applied in reversed order (decoder side)."""

    def __init__(self, config, conv_width, n_depth, reverse_dilation=False):
        super().__init__()
        self.dilation_cycle = config.res_dilation_cycle
        # With `conv_res_scale`, residual branches are damped by 1/sqrt(depth).
        scale = 1.0 / math.sqrt(n_depth) if config.conv_res_scale else 1.0

        stack = []
        for depth_index in range(n_depth):
            effective_depth = depth_index if self.dilation_cycle is None else depth_index % self.dilation_cycle
            stack.append(JukeboxResConv1DBlock(config, conv_width, effective_depth, scale))

        if reverse_dilation:
            stack = stack[::-1]
        self.resnet_block = nn.ModuleList(stack)

    def forward(self, hidden_states):
        for resnet_block in self.resnet_block:
            hidden_states = resnet_block(hidden_states)
        return hidden_states
+
+
class JukeboxEncoderConvBlock(nn.Module):
    """Downsampling tower: `down_t` strided-conv + resnet stages, then a projection to `config.embed_dim`."""

    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t):
        super().__init__()
        filter_t = stride_t * 2
        pad_t = stride_t // 2
        blocks = []
        for stage in range(down_t):
            # Only the first stage consumes `embed_dim` channels; later ones stay at `hidden_dim`.
            in_dim = embed_dim if stage == 0 else hidden_dim
            blocks.append(nn.Conv1d(in_dim, hidden_dim, filter_t, stride_t, pad_t))
            blocks.append(JukeboxResnet1D(config, hidden_dim, depth))
        self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1)
        self.downsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for layer in self.downsample_block:
            hidden_states = layer(hidden_states)
        return self.proj_out(hidden_states)
+
+
class JukeboxEncoder(nn.Module):
    """Hierarchy of encoder conv blocks; level 0 consumes raw audio channels, deeper levels consume embeddings."""

    def __init__(self, config, width, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()

        for level, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            in_dim = config.conv_input_shape if level == 0 else config.embed_dim
            self.level_blocks.append(JukeboxEncoderConvBlock(config, in_dim, width, depth, down_t, stride_t))

    def forward(self, hidden_states):
        all_hidden_states = []

        # Each level further downsamples the previous level's output (e.g. 64, 32, ...).
        for level in range(self.levels):
            hidden_states = self.level_blocks[level](hidden_states)
            all_hidden_states.append(hidden_states)

        return all_hidden_states
+
+
class JukeboxDecoderConvBock(nn.Module):
    """Upsampling tower mirroring `JukeboxEncoderConvBlock`: input projection, then `down_t`
    resnet + transposed-convolution stages.

    Args:
        config: model configuration providing the resnet hyper-parameters.
        embed_dim (`int`): channel width of the latent input (and of the final stage's output).
        hidden_dim (`int`): internal channel width of the tower.
        depth (`int`): number of residual blocks per resnet stage.
        down_t (`int`): number of upsampling stages (mirrors the encoder's downsampling count).
        stride_t (`int`): temporal stride of each transposed convolution.
        reverse_dilation (`bool`, *optional*, defaults to `True`): reverse the dilation order so the
            decoder mirrors the encoder's dilation schedule.
    """

    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True):
        # Fix: call nn.Module.__init__ *before* assigning any attribute, consistently with every
        # other module in this file; assigning before it only works by accident for plain values
        # and would raise for parameters/submodules.
        super().__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        blocks = []
        if down_t > 0:
            filter_t = stride_t * 2
            pad_t = stride_t // 2
            # NOTE(review): `proj_in` is only created when `down_t > 0`, while `forward` always
            # uses it — `down_t == 0` is unsupported, exactly as in the original code.
            self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1)
            for i in range(down_t):
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation))
                blocks.append(
                    nn.ConvTranspose1d(
                        # The last stage projects back down to `embed_dim` channels.
                        hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t
                    )
                )
        self.upsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        """Project the latent input to `hidden_dim`, then apply the resnet/upsample stages."""
        hidden_states = self.proj_in(hidden_states)
        for block in self.upsample_block:
            hidden_states = block(hidden_states)
        return hidden_states
+
+
class JukeboxDecoder(nn.Module):
    """Hierarchy of decoder conv blocks followed by a final projection back to raw-audio channels."""

    def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for level, down_t, stride_t in zip(range(self.levels), downs_t, strides_t):
            self.level_blocks.append(
                JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t)
            )

        self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1)

    def forward(self, hidden_states, all_levels=True):
        """Decode the coarsest latent, optionally summing in the finer latents on the way down."""
        hidden_state = hidden_states[-1]

        # Walk the levels from coarsest to finest (e.g. 32, 64, ...).
        for level in reversed(range(self.levels)):
            hidden_state = self.level_blocks[level](hidden_state)

            # Inject the next-finer latent before the following decoder level.
            if all_levels and level != 0:
                hidden_state = hidden_state + hidden_states[level - 1]

        return self.out(hidden_state)
+
+
class JukeboxBottleneckBlock(nn.Module):
    """Single-level vector-quantisation bottleneck.

    Continuous latents are assigned to their nearest codebook vector (`quantise`) and mapped back
    (`dequantise`). The codebook is maintained with exponential-moving-average k-means style
    updates rather than gradients; gradients flow to the encoder through the straight-through
    estimator in `forward`.
    """

    def __init__(self, config: JukeboxVQVAEConfig):
        super().__init__()
        self.nb_discrete_codes = config.nb_discrete_codes
        self.codebook_width = config.embed_dim
        self.mu = config.lmu  # EMA decay rate for the codebook statistics
        self.threshold = 1.0  # codes with lower EMA usage get re-seeded with random encoder outputs
        self.init = False  # codebook is initialised lazily from the first batch (see `init_codebook`)
        self.codebook_sum = None
        self.codebook_elem = None
        self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width))

    def _tile(self, hidden_states):
        # Repeat the batch (with a little noise) until there are at least `nb_discrete_codes`
        # candidate vectors to draw codebook entries from.
        dim, embed_width = hidden_states.shape
        if dim < self.nb_discrete_codes:
            n_repeats = (self.nb_discrete_codes + dim - 1) // dim
            std = 0.01 / np.sqrt(embed_width)
            hidden_states = hidden_states.repeat(n_repeats, 1)
            hidden_states = hidden_states + torch.randn_like(hidden_states) * std
        return hidden_states

    def init_codebook(self, hidden_states):
        """Seed the codebook with a random subset of the first batch's encoder outputs."""
        nb_discrete_codes = self.nb_discrete_codes
        self.init = True
        codes = self._tile(hidden_states)
        self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
        self.codebook_sum = self.codebook
        self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device)

    def update_codebook(self, hidden_states, latent_states):
        """EMA (k-means style) codebook update; returns usage/entropy diagnostics."""
        mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes
        with torch.no_grad():
            # Calculate new centres
            # nb_discrete_codes, batch_size * seq_length
            latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device)
            latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1)

            _codebook_sum = torch.matmul(latent_states_onehot, hidden_states)
            _codebook_elem = latent_states_onehot.sum(dim=-1)  # nb_discrete_codes
            codes = self._tile(hidden_states)
            _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]

            # Update centres
            old_codebook = self.codebook
            self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum
            self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem  # nb_discrete_codes
            # Codes whose EMA usage falls below `threshold` are replaced by random encoder outputs.
            usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float()

            norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view(
                nb_discrete_codes, 1
            )
            self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook
            _codebook_prob = _codebook_elem / torch.sum(_codebook_elem)  # prob of each bin
            entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # entropy ie how diverse
            used_curr = (_codebook_elem >= self.threshold).sum()
            usage = torch.sum(usage)
            # `dk`: normalised distance the codebook moved this step.
            dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
        return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}

    def preprocess(self, hidden_states):
        # (batch, width, seq) -> (batch * seq, width)
        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if hidden_states.shape[-1] == self.codebook_width:
            prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(
                np.prod(hidden_states.shape)
            )
        elif hidden_states.shape[-1] == 2 * self.codebook_width:
            # Doubled-width inputs: each half is normed separately, then the halves are summed.
            x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
            prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
                torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
            )

            # Normalise
            hidden_states = x1 + x2

        # NOTE(review): any other input width leaves `prenorm` unbound and raises — callers are
        # expected to pass only these two widths.
        return hidden_states, prenorm

    def postprocess(self, latent_states, dequantised_states, x_shape):
        # (batch * seq, width) -> (batch, width, seq); tokens -> (batch, seq)
        batch_size, time = x_shape
        dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous()
        latent_states = latent_states.view(batch_size, time)
        return latent_states, dequantised_states

    def quantise(self, latent_states):
        # Calculate latent code latent_states
        codebook_weights = self.codebook.t()
        # Squared euclidean distance expanded as |x|^2 - 2 x·c + |c|^2.
        distance = (
            torch.sum(latent_states**2, dim=-1, keepdim=True)
            - 2 * torch.matmul(latent_states, codebook_weights)
            + torch.sum(codebook_weights**2, dim=0, keepdim=True)
        )  # (batch_size * latent_states , codebook_weights)
        min_distance, music_tokens = torch.min(distance, dim=-1)
        # Mean distance to the nearest code — a quantisation-quality metric.
        fit = torch.mean(min_distance)
        return music_tokens, fit

    def dequantise(self, music_tokens):
        # Token -> codebook vector lookup.
        dequantised_states = F.embedding(music_tokens, self.codebook)
        return dequantised_states

    def encode(self, latent_states):
        """Quantise `latent_states` of shape (batch, width, seq) into music tokens (batch, seq)."""
        samples, _, seq_len = latent_states.shape

        # Preprocess.
        latent_states, _ = self.preprocess(latent_states)

        # Quantise
        music_tokens, _ = self.quantise(latent_states)

        # Postprocess.
        music_tokens = music_tokens.view(samples, seq_len)
        return music_tokens

    def decode(self, music_tokens):
        """Map music tokens (batch, seq) back to latent vectors (batch, width, seq)."""
        samples, seq_len = music_tokens.shape

        # Dequantise
        dequantised_states = self.dequantise(music_tokens)

        # Postprocess
        dequantised_states = (
            dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous()
        )
        return dequantised_states

    def forward(self, hidden_states, update_codebook=True):
        """Quantise `hidden_states`; returns (tokens, straight-through dequantised states, commit loss, metrics)."""
        samples, _, seq_len = hidden_states.shape

        # Preprocess
        hidden_states, prenorm = self.preprocess(hidden_states)

        # Init codebook if not inited
        if update_codebook and not self.init:
            self.init_codebook(hidden_states)

        # Quantise and dequantise through bottleneck
        music_tokens, fit = self.quantise(hidden_states)
        dequantised_states = self.dequantise(music_tokens)

        # Update embeddings
        if update_codebook:
            update_metrics = self.update_codebook(hidden_states, music_tokens)
        else:
            update_metrics = {}

        # Loss
        commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(
            hidden_states.shape
        )

        # Passthrough (straight-through estimator: gradients bypass the quantisation step)
        dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()

        # Postprocess
        music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len))
        return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)
+
+
class JukeboxBottleneck(nn.Module):
    """Collection of per-level `JukeboxBottleneckBlock`s."""

    def __init__(self, config, levels):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList(JukeboxBottleneckBlock(config) for _ in range(levels))

    def encode(self, raw_audio):
        """Quantise each level's latents into music tokens."""
        return [block.encode(latents) for block, latents in zip(self.level_blocks, raw_audio)]

    def decode(self, music_tokens, start_level=0, end_level=None):
        """Dequantise music tokens back into latent vectors for the requested range of levels."""
        end_level = self.levels if end_level is None else end_level
        selected_blocks = self.level_blocks[start_level:end_level]
        return [block.decode(tokens) for block, tokens in zip(selected_blocks, music_tokens)]

    def forward(self, input_audio):
        music_tokens, quantised_states, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            # NOTE: blocks are deliberately indexed back-to-front, matching the original code.
            level_block = self.level_blocks[-level - 1]
            hidden_states = input_audio[level]
            sampled_tokens, quantised_state, commit_loss, metric = level_block(
                hidden_states, update_codebook=self.training
            )
            music_tokens.append(sampled_tokens)
            if not self.training:
                # Be extra paranoid and make sure the encoder weights can't
                # change from straight-through estimator
                quantised_state = quantised_state.detach()
            quantised_states.append(quantised_state)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return music_tokens, quantised_states, commit_losses, metrics
+
+
+JUKEBOX_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config (`JukeboxConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
@add_start_docstrings(
    """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam
Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111).

    """,
    JUKEBOX_START_DOCSTRING,
)
class JukeboxVQVAE(PreTrainedModel):
    config_class = JukeboxVQVAEConfig
    base_model_prefix = "vqvae"

    def _init_weights(self, module):
        """Initialise weights following the original Jukebox scheme."""
        if isinstance(module, nn.Embedding):  # embed_tokens
            module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        elif isinstance(module, JukeboxConv1D):
            # `zero_out` starts projections at zero so residual branches begin as identities.
            if self.config.zero_out:
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            module.conv1d_2.weight.data.zero_()
            module.conv1d_2.bias.data.zero_()
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def __init__(self, config: JukeboxVQVAEConfig):
        super().__init__(config)
        downs_t = config.res_downs_t
        strides_t = config.res_strides_t
        if not config.sample_length:
            # Derive `sample_length` from the duration in seconds, rounded down to a whole
            # number of top-level tokens.
            downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
            top_raw_to_tokens = np.prod(downsamples)
            config.sample_length = (
                config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens
            ) * top_raw_to_tokens
            config.sample_length = config.sample_length.astype(int)

        self.nb_discrete_codes = config.nb_discrete_codes
        self.commit = config.commit
        self.sample_length = config.sample_length

        # Per-level temporal downsampling factor, and its cumulative product (hop length).
        self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
        self.hop_lengths = np.cumprod(self.downsamples)
        self.levels = levels = config.levels
        # Token-sequence length of each level (coarsest level has the largest hop).
        self.music_tokens_shapes = [
            (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels)
        ]

        self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            width = config.res_conv_width * self.multipliers[level]
            depth = config.res_conv_depth * self.multipliers[level]
            self.encoders.append(
                JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )
            self.decoders.append(
                JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )

        self.bottleneck = JukeboxBottleneck(config, levels)

    def _decode(self, music_tokens, start_level=0, end_level=None):
        # Decode
        if end_level is None:
            end_level = self.levels
        latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level)
        # Use only lowest level
        decoder, dequantised_state = self.decoders[start_level], latent_states[0:1]
        dequantised_state = decoder(dequantised_state, all_levels=False)
        # (batch, channels, time) -> (batch, time, channels)
        dequantised_state = dequantised_state.permute(0, 2, 1)
        return dequantised_state

    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor:
        """
        Transforms the input `music_tokens` to their `raw_audio` representation.

        Args:
            music_tokens (`torch.LongTensor`):
                Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token
                should be an index to a corresponding `code` vector in the codebook.
            start_level (`int`, *optional*):
                Level at which the decoding process will start. Default to 0.
            end_level (`int`, *optional*):
                Level at which the decoding process will stop. Default to None.
            bs_chunks (int, *optional*):
                Number of chunks to process at the same time.
        """
        # Split along the batch dimension so large batches can be decoded chunk by chunk.
        token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens]
        dequantised_states = []
        for i in range(bs_chunks):
            music_tokens_i = [chunks[i] for chunks in token_chunks]
            dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level)
            dequantised_states.append(dequantised_state)
        return torch.cat(dequantised_states, dim=0)

    def _encode(self, raw_audio, start_level=0, end_level=None):
        # Encode
        if end_level is None:
            end_level = self.levels
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        # Every encoder consumes the raw audio; only its deepest output is quantised.
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])
        music_tokens = self.bottleneck.encode(latent_states)
        return music_tokens[start_level:end_level]

    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
        """
        Transforms the `input_audio` to a discrete representation made out of `music_tokens`.

        Args:
            input_audio (`torch.Tensor`):
                Raw audio which will be encoded to its discrete representation using the codebook. The closest `code`
                from the codebook will be computed for each sequence of samples.
            start_level (`int`, *optional*, defaults to 0):
                Level at which the encoding process will start. Default to 0.
            end_level (`int`, *optional*):
                Level at which the encoding process will stop. Default to None.
            bs_chunks (int, *optional*, defaults to 1):
                Number of chunks of raw audio to process at the same time.
        """
        audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0)
        music_tokens_list = []
        for chunk_i in audio_chunks:
            music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level)
            music_tokens_list.append(music_tokens_i)
        # Re-join the chunks level by level.
        music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)]
        return music_tokens

    def sample(self, n_samples):
        """Decode `n_samples` sequences of uniformly random music tokens into raw audio."""
        music_tokens = [
            torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu")
            for music_tokens_shape in self.music_tokens_shapes
        ]
        return self.decode(music_tokens)

    def forward(self, raw_audio: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the VQ-VAE, encodes the `raw_audio` to latent states, which are then decoded for each level.
        The commit loss, which ensure that the encoder's computed embeddings are close to the codebook vectors, is
        computed.

        Args:
            raw_audio (`torch.FloatTensor`):
                Audio input which will be encoded and decoded.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`


        Example:
        ```python
        >>> from transformers import JukeboxVQVAE, set_seed
        >>> import torch

        >>> model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval()
        >>> set_seed(0)
        >>> zs = [torch.randint(100, (4, 1))]
        >>> model.decode(zs).shape
        torch.Size([4, 8, 1])
        ```
        """

        # Encode/Decode
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])

        # Quantise each level, then reconstruct each level independently.
        _, music_tokens, commit_losses, _ = self.bottleneck(latent_states)
        dequantised_states = []
        for level in range(self.levels):
            decoder = self.decoders[level]
            dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False)
            dequantised_states.append(dequantised_state.permute(0, 2, 1))

        commit_loss = sum(commit_losses)
        loss = self.commit * commit_loss

        return dequantised_states, loss
+
+
class JukeboxMLP(nn.Module):
    """Transformer feed-forward block: widen by `mlp_multiplier`, activate, project back, dropout."""

    def __init__(self, config):
        # a single channel is always used in original code
        super().__init__()
        embed_dim = config.hidden_size
        hidden_dim = int(config.mlp_multiplier * embed_dim)

        self.c_fc = JukeboxConv1D(embed_dim, hidden_dim)
        self.c_proj = JukeboxConv1D(hidden_dim, embed_dim)
        self.act = ACT2FN[config.act_fn]
        self.dropout = nn.Dropout(config.resid_dropout)

    def forward(self, hidden_states):
        return self.dropout(self.c_proj(self.act(self.c_fc(hidden_states))))
+
+
class JukeboxLayerNorm(FusedLayerNorm):
    """LayerNorm that falls back to the unfused `F.layer_norm` for very large inputs."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(normalized_shape)
        # 65535 rows of `width` elements — presumably the fused kernel's size limit; TODO confirm.
        self.max_numel = 65535 * self.width

    def forward(self, input):
        if input.numel() <= self.max_numel:
            return super().forward(input).type_as(input)
        return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
+
+
+class JukeboxAttention(nn.Module):
    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        """Multi-head attention supporting several factored/sparse patterns selected by `attn_func`.

        Args:
            config: model configuration (hidden size, number of heads, dropouts, block layout, ...).
            n_ctx (`int`): length of the attention context handled by this layer.
            attn_func (`str`, *optional*, defaults to `"dense_attn"`): which attention pattern to use
                (one of the keys of `ATTENTION_MAP` below).
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        self.n_heads = config.n_heads
        self.dropout = config.attn_dropout
        hidden_dim = int(config.attention_multiplier * self.embed_dim)

        self.head_dim = hidden_dim // config.n_heads
        self.n_ctx = n_ctx
        self.hidden_dim = hidden_dim
        # Applied to queries AND keys in `_attn`, so the combined factor is the usual 1/sqrt(head_dim).
        self.scale = self.head_dim**-0.25
        self.mask = config.mask

        if attn_func == "cross_attention":
            # Cross-attention: queries come from the current stream, keys/values from the encoder.
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim)
            self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2)
        else:
            # Otherwise q, k and v all come from one fused projection.
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3)

        self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        # Sequence of length seq_len is factored as [blocks, seq_len // blocks]
        self.attn_func = attn_func
        # How q/k/v are produced depends on the attention flavour.
        if attn_func == "cross_attention":
            self.qkv = self.decode_qkv
        elif attn_func == "prime_attn":
            self.qkv = self.prime_qkv
        else:
            self.qkv = self.factored_qkv

        # Maps each flavour to (attention implementation, mask type passed to `get_mask`).
        ATTENTION_MAP = {
            "dense_attn": (self.dense_attn, "autoregressive"),
            "block_attn": (self.block_attn, "autoregressive"),
            "transpose_block_attn": (self.transpose_block_attn, "autoregressive"),
            "prev_block_attn": (self.prev_block_attn, None),
            "summary_attn": (self.summary_attn, "summary"),
            "summary_spread_attn": (self.summary_spread_attn, "summary"),
            "cross_attention": (self.dense_attn, None),
            "prime_attn": (self.prime_attn, "prime"),
        }
        self.attn, self.attn_mask = ATTENTION_MAP[attn_func]

        self.blocks = config.blocks
        self.spread = config.spread
        if self.blocks is not None:
            self.block_ctx = self.n_ctx // self.blocks

        self.sample_t = 0  # number of tokens generated so far while sampling
        self.cache = {}  # key/value cache used during sampling
        self.encoder_len = config.nb_relevant_lyric_tokens  # length of the encoder input ids
        self.record_attn = False
+
    def _attn(self, query_states, key_states, value_states, sample):
        """Scaled dot-product attention with optional masking and attention-probability recording."""
        scale = self.scale
        if self.training:
            # Pre-scale q and k separately (scale**2 == 1/sqrt(head_dim)) to keep intermediates small.
            attention_weight = torch.matmul(query_states * scale, key_states * scale)
        else:
            # At inference a single fused in-place multiply is cheaper.
            attention_weight = torch.matmul(query_states, key_states)
            attention_weight.mul_(scale * scale)
        attn_weight_type = attention_weight.dtype
        # Softmax is computed in float32 for numerical stability, then cast back.
        attention_weight = attention_weight.float()
        if self.mask:
            # Generate appropriate mask to mask out all positions before current
            # Might take up lot of memory for dense, so can cache it
            mask = get_mask(
                self.attn_mask,
                query_states.size(-2),
                key_states.size(-1),
                self.blocks,
                self.spread,
                attention_weight.device,
                sample,
                self.sample_t,
            )
            if mask is not None:
                attention_weight = attention_weight * mask + -1e9 * (1 - mask)
        attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type)
        if self.record_attn:
            self.attention_prob = attention_prob
            if self.attn_func == "prime_attn":
                # only keep music queries and lyrics keys/values
                self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len]
        attention_prob = self.attn_dropout(attention_prob)
        context_states = torch.matmul(attention_prob, value_states)
        return context_states
+
+ def merge_heads(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1))
+ return hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct merge_states
+
+ def split_heads(self, hidden_states, is_key=False):
+ new_hidden_states_shape = (
+ *hidden_states.size()[:-1],
+ self.n_heads,
+ hidden_states.size(-1) // self.n_heads,
+ )
+ hidden_states = hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct split_states
+ if is_key:
+ return hidden_states.permute(0, 2, 3, 1)
+ else:
+ return hidden_states.permute(0, 2, 1, 3)
+
+ def dense_attn(self, query, key, value, sample):
+ query = self.split_heads(query)
+ key = self.split_heads(key, is_key=True)
+ value = self.split_heads(value)
+ context_states = self._attn(query, key, value, sample)
+ context_states = self.merge_heads(context_states)
+ return context_states
+
    def block_attn(self, query, key, value, sample):
        """Attention restricted to independent contiguous blocks of length `block_ctx`."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            # Single-token sampling: attend densely over the cached keys/values.
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            # Fold the block dimension into the batch so each block attends only within itself.
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
            if query_length < seq_len:
                # Queries cover only the tail of the sequence; keep the matching keys/values.
                seq_len = query_length
                key = key[:, -seq_len:].contiguous()
                value = value[:, -seq_len:].contiguous()
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
+
    def transpose_block_attn(self, query, key, value, sample):
        """Attention across blocks: position i of a block attends to position i of the other blocks."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            # Only keys/values at the same intra-block offset as the current token are relevant.
            block_len = (seq_len - 1) % block_ctx
            key = key[:, block_len::block_ctx, :]
            value = value[:, block_len::block_ctx, :]
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            # Regroup as (batch * block_ctx, n_blocks, embed_dim) so attention runs across blocks.
            query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim)
            query = query.transpose(1, 2).contiguous()
            query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim)

            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            key = key.transpose(1, 2).contiguous()
            key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            value = value.transpose(1, 2).contiguous()
            value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            block_attn = self.dense_attn(query, key, value, sample)
            # Undo the transpose to restore (batch, seq, embed_dim).
            block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim)
            block_attn = block_attn.transpose(1, 2).contiguous()
            block_attn = block_attn.view(batch_size, query_length, embed_dim)

            return block_attn
+
    def prev_block_attn(self, query, key, value, sample):
        """Each block attends only to the immediately preceding block (zeros for the first block)."""
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            block = (seq_len - 1) // block_ctx
            prev_l = (block - 1) * block_ctx
            if block > 0:
                key = key[:, prev_l : prev_l + block_ctx, :]
                value = value[:, prev_l : prev_l + block_ctx, :]
            else:
                # No previous block yet: attend over zeros.
                key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
                value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)

            # Shift keys/values one block to the right; the first block sees zero padding.
            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0))
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0))
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            if query_length < seq_len:
                # Queries only cover the tail; keep the matching key/value blocks.
                nb_query_blocks = query_length // block_ctx
                nb_key_blocks = seq_len // block_ctx
                seq_len = query_length
                key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

                value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
+
    def summary_attn(self, query, key, value, sample):
        """Attend to one summary position (the last token) of every previous block."""
        blocks = self.blocks
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            # Keep the final token of each completed block; pad a zero in front for the first block.
            key = key[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))

            value = value[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            # Last token of each block except the final one, shifted right by one block via padding.
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))  # batch_size, blocks, embed_dim

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))  # batch_size, blocks, embed_dim
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
+
    def summary_spread_attn(self, query, key, value, sample):
        """Like `summary_attn`, but keeps the last `spread` tokens of each previous block."""
        blocks = self.blocks
        spread = self.spread

        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            # Incremental sampling is not supported for this pattern.
            raise NotImplementedError
        else:
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous()
            key = key.view(batch_size, blocks * spread, embed_dim)

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous()
            value = value.view(batch_size, blocks * spread, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)
+
+ def prime_attn(self, query, key, value, sample):
+ encoder_len = self._encoder_len
+ key = key[:, :encoder_len]
+ value = value[:, :encoder_len]
+ return self.dense_attn(query, key, value, sample)
+
    def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """Split the fused projection into q/k/v and manage the sampling cache for factored attention."""
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")

        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            self.sample_t += curr_ctx
            key, value = self._append_cache(key, value)
            # Trim the cache to the minimum length this attention pattern needs.
            l_cache = self._suff_cache_len()
            if self._cache_len() > l_cache:
                self._slice_cache(-l_cache)
            if curr_ctx > 1:
                # Multi-token (primed) step: pad to whole blocks and use the batched path.
                if self.attn_func != "dense_attn":
                    query = self._pad_to_block_ctx(query, query=True)
                    key = self._pad_to_block_ctx(key)
                    value = self._pad_to_block_ctx(value)
                sample = False
            else:
                # Single-token step: attend over the whole cache.
                key = self.cache["key"]
                value = self.cache["value"]
        return query, key, value, sample
+
    def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """q/k/v for prime (lyric) attention; keys/values are clamped to the encoder span while sampling."""
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")
        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            # Cache exactly `_encoder_len` lyric keys/values, then reuse them for every music token.
            if self._cache_len() < self._encoder_len:
                self._append_cache(key, value)
            if self._cache_len() > self._encoder_len:
                self._slice_cache(0, self._encoder_len)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        return query, key, value, sample
+
    def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """q/k/v for cross-attention: queries from the current stream, keys/values from the encoder output."""
        curr_ctx = hidden_states.shape[1]
        query = hidden_states
        if sample:
            # The encoder projection is computed once at the first sampling step and cached.
            if self.sample_t == 0:
                self.cache["key"], self.cache["value"] = self.c_enc_kv(
                    last_encoder_hidden_states.type_as(hidden_states)
                ).chunk(2, dim=2)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        else:
            key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2)
        return query, key, value, sample
+
    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        """
        Full attention pass: fused input projection, variant-specific q/k/v preparation (`self.qkv`),
        attention, then output projection and residual dropout. If the attention output was block-padded,
        it is sliced back to the current context length using the block offset.
        """
        curr_ctx = hidden_states.shape[1]
        hidden_states = self.c_attn(hidden_states)
        query, key, value, sample = self.qkv(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
        )
        attention_scores = self.attn(query, key, value, sample)
        if attention_scores.shape[1] != curr_ctx:
            # Undo the block padding applied in `_pad_to_block_ctx`.
            offset = self._offset(curr_ctx)
            attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous()
        attention_scores = self.c_proj(attention_scores)
        return self.resid_dropout(attention_scores)
+
+ @property
+ def _encoder_len(self):
+ encoder_len = self.encoder_len
+ encoder_blocks = (encoder_len // self.blocks) + 1
+ return encoder_blocks * self.blocks
+
+ def _offset(self, curr_ctx):
+ if self.attn_func == "dense_attn":
+ return 0
+ return (self.sample_t - curr_ctx) % self.block_ctx
+
+ def _pad_to_block_ctx(self, hidden_states, query=False):
+ seq_len = hidden_states.shape[1]
+ offset = self._offset(seq_len) if query else 0
+ n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx
+ pad = n_blocks * self.block_ctx - seq_len - offset
+ if pad == 0 and offset == 0:
+ return hidden_states
+ else:
+ return F.pad(hidden_states, (0, 0, offset, pad))
+
+ def _cache_len(self):
+ return 0 if "key" not in self.cache else self.cache["key"].shape[1]
+
+ def _suff_cache_len(self):
+ """
+ Precondition:
+ key and value are appended with the current context and self.sample_t reflects the 1-indexed sample
+ location in the context.
+ """
+ previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx
+ REQUIRED_CACHE_LEN = {
+ "dense_attn": self.sample_t,
+ "block_attn": (self.sample_t - 1) % self.block_ctx + 1,
+ "transpose_block_attn": self.sample_t,
+ "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length,
+ "cross_attn": self.encoder_len,
+ "prime_attn": min(self.sample_t, self._encoder_len),
+ }
+
+ return REQUIRED_CACHE_LEN[self.attn_func]
+
+ def _slice_cache(self, start, end=None):
+ self.cache["key"] = self.cache["key"][:, start:end]
+ self.cache["value"] = self.cache["value"][:, start:end]
+
+ def _append_cache(self, key, value):
+ if "key" not in self.cache:
+ self.cache["key"] = key
+ self.cache["value"] = value
+ else:
+ old_key, old_value = key, value
+ key = torch.cat([self.cache["key"], old_key], dim=1)
+ value = torch.cat([self.cache["value"], old_value], dim=1)
+ del self.cache["key"]
+ del self.cache["value"]
+ del old_key
+ del old_value
+ self.cache["key"] = key
+ self.cache["value"] = value
+ return self.cache["key"], self.cache["value"]
+
+ def del_cache(self):
+ self.sample_t = 0
+ if "key" in self.cache:
+ del self.cache["key"]
+ if "value" in self.cache:
+ del self.cache["value"]
+ self.cache = {}
+
+
class JukeboxBlock(nn.Module):
    """A single transformer block: pre-norm attention followed by a pre-norm MLP,
    with optional down-scaling of the residual branches."""

    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        super().__init__()
        self.width = config.hidden_size
        self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func)

        self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size)
        self.mlp = JukeboxMLP(config)
        self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size)
        # Scale residual contributions by 1/num_layers when the config asks for it.
        self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0
        self.attn_func = attn_func

    def forward(self, hidden_states, last_encoder_hidden_states, sample=False):
        residuals = hidden_states
        attn_out = self.attn(self.layer_norm_0(hidden_states), last_encoder_hidden_states, sample)
        mlp_out = self.mlp(self.layer_norm_1(residuals + attn_out))
        if self.res_scale == 1.0:
            return residuals + attn_out + mlp_out
        return residuals + self.res_scale * (attn_out + mlp_out)
+
+
class JukeboxLayerStack(nn.Module):
    """Stack of `JukeboxBlock`s whose per-layer attention functions follow a configured pattern."""

    def __init__(self, config, n_ctx):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = config.hidden_size
        self.num_layers = config.num_layers
        self.blocks = config.blocks
        self.attention_pattern = config.attention_pattern
        if self.blocks is not None:
            self.block_ctx = n_ctx // self.blocks
        self.encoder_len = config.nb_relevant_lyric_tokens
        self.n_heads = config.n_heads

        # One block per layer; the pattern maps the layer depth to its attention function.
        pattern_fn = ATTENTION_PATTERNS[self.attention_pattern]
        self._attn_mods = nn.ModuleList(
            JukeboxBlock(config, n_ctx, attn_func=pattern_fn(depth)) for depth in range(self.num_layers)
        )

        self.saved_attn_weights = []

    def set_record_attn(self, record_attn):
        """
        Makes forward prop dump self-attention softmaxes to self.saved_attn_weights.

        Args:
            record_attn (`Union[bool, set]`):
                Either a set of layer indices indicating which layers to store, or a boolean indicating
                whether to dump all of them.
        """
        for layer_idx, layer in enumerate(self._attn_mods):
            if isinstance(record_attn, bool):
                layer.attn.record_attn = record_attn
            else:
                layer.attn.record_attn = layer_idx in record_attn

        if not record_attn:
            self.saved_attn_weights = []

    def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        for attn_layer in self._attn_mods:
            # Only cross-attention layers get to see the lyric encoder states.
            encoder_states = last_encoder_hidden_states if attn_layer.attn_func == "cross_attention" else None
            hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=encoder_states, sample=sample)
            if attn_layer.attn.record_attn:
                self.saved_attn_weights.append(attn_layer.attn.c_attn.weight)
        return hidden_states

    def del_cache(self):
        """Clear the key/value cache of every block."""
        for attn_layer in self._attn_mods:
            attn_layer.attn.del_cache()
+
+
class JukeboxPositionalEmbedding(nn.Module):
    """Learned positional embedding table, returned as-is on every forward call."""

    def __init__(self, embed_dim, width):
        super().__init__()
        # One learned vector of size `width` per position.
        self.pos_emb = nn.Parameter(torch.empty((embed_dim, width)))

    def forward(self):
        return self.pos_emb
+
+
class JukeboxConditionalAutoregressive(nn.Module):
    def __init__(
        self,
        config,
        n_ctx=None,
        embed_dim=None,
        audio_conditioning=False,
        metadata_conditioning=False,
        is_encoder=False,
    ):
        """
        Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly
        set for each configuration.

        Args:
            config (`JukeboxPriorConfig`):
                Model configuration class with all the parameters of the model. Initializing with a config file does
                not load the weights associated with the model, only the configuration. Check out the
                [`~PreTrainedModel.from_pretrained`] method to load the model weights.
            n_ctx (`int`, *optional*):
                Number of tokens or lyrics tokens provided in a single pass.
            embed_dim (`int`, *optional*):
                Either equals to the dimension of the codebook, or the sum of n_vocab (lyrics) and codebook dimension,
                if the model combines lyrics and music tokens, or simply n_vocab if the model is a separate encoder
            audio_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on audio.
            metadata_conditioning (`bool`, *optional*, defaults to `False`):
                Whether or not the prior supports conditioning on artist, genres, lyrics and timing.
            is_encoder (`bool`, *optional*, defaults to `False`):
                Whether the model is an encoder only model.
        """

        super().__init__()
        self.width = config.hidden_size
        self.num_layers = config.num_layers
        self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx
        self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size
        self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size)
        self.embed_tokens_dropout = nn.Dropout(config.emb_dropout)
        self.metadata_conditioning = metadata_conditioning
        self.audio_conditioning = audio_conditioning
        if not metadata_conditioning:
            # Without metadata, a learned start token marks the beginning of a sequence.
            self.start_token = nn.Parameter(torch.empty((1, config.hidden_size)))
        self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size)
        self.pos_emb_dropout = nn.Dropout(config.emb_dropout)

        self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx)
        self.is_encoder = is_encoder
        self.encoder_len = config.nb_relevant_lyric_tokens

        if config.merged_decoder:
            # Merged piped model uses this setup
            self.add_cond_after_transformer = False
            self.share_embed_tokens_fc_proj_out = False
        else:
            self.add_cond_after_transformer = True
            self.share_embed_tokens_fc_proj_out = True

        if not is_encoder:
            self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False)
            if self.share_embed_tokens_fc_proj_out:
                # Weight tying between the input embedding and the output projection.
                self.fc_proj_out.weight = self.embed_tokens.weight
            self.loss = torch.nn.CrossEntropyLoss()

    def forward(
        self,
        tokens,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        get_preds=False,
        get_acts=False,
        get_sep_loss=False,
    ):
        """
        Teacher-forced forward pass: embeds `tokens`, shifts them right by one (the first position is
        replaced by the metadata embedding or the learned start token), runs the transformer and
        returns the next-token prediction loss in bits (natural-log loss divided by log 2).

        Args:
            tokens (`torch.tensor`):
                Can represent music tokens, lyrics tokens or both, depending on the configuration.
        """
        # Preprocess.
        batch_size = tokens.shape[0]
        with torch.no_grad():
            tokens = tokens.view(batch_size, -1).long()

        if not self.audio_conditioning:
            # Zero conditioning placeholder so the additions below are no-ops.
            audio_conditioning = torch.zeros(
                (batch_size, 1, self.width),
                device=tokens.device,
                dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype,
            )

        target = tokens  # Target
        hidden_states = self.embed_tokens(tokens)
        # Shift by 1, and fill in start token
        hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1)
        if self.metadata_conditioning:
            hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width)
        else:
            hidden_states[:, 0] = self.start_token

        hidden_states = (
            self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning
        )  # Pos emb and dropout

        hidden_states = self.transformer(
            hidden_states, last_encoder_hidden_states=last_encoder_hidden_states
        )  # Transformer
        if self.add_cond_after_transformer:  # Piped doesnt add x_cond
            hidden_states = hidden_states + audio_conditioning

        activations = hidden_states
        if self.is_encoder:
            return hidden_states

        hidden_states = self.fc_proj_out(hidden_states)  # Predictions
        # NOTE(review): a fresh loss is instantiated here even though `self.loss` exists.
        loss_fn = nn.CrossEntropyLoss()
        if get_sep_loss:
            # Report lyric and music-token losses separately (sequence = lyrics then music).
            lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim)
            token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim)

            lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0)
            music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0)

            loss = (lyric_loss, music_token_loss)  # Note order! Lyric is first
        else:
            loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0)  # Loss

        if get_preds:
            return loss, hidden_states
        elif get_acts:
            return loss, activations
        else:
            return loss, None

    def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning):
        """
        Input embedding for one sampling step: the start/metadata token at step 0, otherwise the
        embedding of the previously sampled `tokens`, plus the step's positional embedding and the
        matching slice of the audio conditioning.
        """
        if sample_t == 0:
            hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to(
                self.embed_tokens.weight.device
            )
            if self.metadata_conditioning:
                hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width)
            else:
                hidden_states[:, 0] = self.start_token
        else:
            hidden_states = self.embed_tokens(tokens)
        # Per-step conditioning slice when the conditioning covers the full context.
        if audio_conditioning.shape == (n_samples, self.n_ctx, self.width):
            cond = audio_conditioning[:, sample_t : sample_t + 1, :]
        else:
            cond = audio_conditioning
        # Pos emb, dropout is identity at eval time
        hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond
        return hidden_states, cond

    def sample(
        self,
        n_samples,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        temp=1.0,
        top_k=0,
        top_p=0.0,
        get_preds=False,
        sample_tokens=None,
    ):
        """
        Ancestral sampling: draws `sample_tokens` music tokens one step at a time from the categorical
        distribution over logits (after temperature scaling and top-k/top-p filtering), using the
        transformer's key/value cache. Returns the sampled tokens (and logits when `get_preds`).
        """
        if sample_tokens is None:
            sample_tokens = self.n_ctx

        if not self.audio_conditioning:
            audio_conditioning = torch.zeros(
                (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
            ).to(self.fc_proj_out.device)

        with torch.no_grad():
            sampled_tokens = []
            tokens = None
            if get_preds:
                preds = []

            iter = tqdm(range(0, sample_tokens), leave=False)
            for sample_t in iter:
                iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True)
                hidden_states, cond = self.get_emb(
                    sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning
                )

                hidden_states = self.transformer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
                )
                if self.add_cond_after_transformer:
                    hidden_states = hidden_states + cond
                hidden_states = self.fc_proj_out(hidden_states)  # Predictions
                if get_preds:
                    preds.append(hidden_states.clone())
                # Adjust logits
                hidden_states = hidden_states / temp
                hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
                # Sample and replace hidden_states
                tokens = torch.distributions.Categorical(logits=hidden_states).sample()
                sampled_tokens.append(tokens.clone())

            del tokens
            self.transformer.del_cache()

            tokens = torch.cat(sampled_tokens, dim=1)
            if get_preds:
                preds = torch.cat(preds, dim=1)
        if get_preds:
            return tokens, preds
        else:
            return tokens

    def split_chunks(self, length, chunk_size):
        """Split `length` positions into chunks of `chunk_size` (the last chunk takes the remainder)."""
        n_passes = (length + chunk_size - 1) // chunk_size
        chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]
        return chunk_sizes
    def primed_sample(
        self,
        n_samples,
        lyric_and_music_tokens,
        audio_conditioning=None,
        metadata_conditioning=None,
        last_encoder_hidden_states=None,
        temp=1.0,
        top_k=0,
        top_p=0.0,
        get_preds=False,
        chunk_size=None,
        sample_tokens=None,
    ):
        """
        Sampling continued from a prime: first runs the provided `lyric_and_music_tokens` through the
        transformer (in chunks of `chunk_size`) to fill the key/value cache, then ancestrally samples
        the remaining tokens up to `sample_tokens` exactly as `sample` does.
        """
        if sample_tokens is None:
            sample_tokens = self.n_ctx
        # Preprocess.
        batch_size = lyric_and_music_tokens.shape[0]
        with torch.no_grad():
            lyric_and_music_tokens = lyric_and_music_tokens.view(batch_size, -1).long()

        sampled_audio = torch.split(lyric_and_music_tokens, 1, dim=1)
        sampled_audio = list(sampled_audio)

        if not self.audio_conditioning:
            audio_conditioning = torch.zeros(
                (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
            ).to(lyric_and_music_tokens.device)

        with torch.no_grad():
            if get_preds:
                preds = []

            # Fill up key/value cache for past context by runing forward pass.
            # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
            if chunk_size is None:
                chunk_size = len(sampled_audio)
            chunk_sizes = self.split_chunks(len(sampled_audio), chunk_size)
            x_primes = []
            start = 0
            token = None

            for current_chunk_size in tqdm(chunk_sizes, desc="Preparing past key value", leave=False):
                sampled_audio_prime, conds_prime = [], []
                for sample_t in range(start, start + current_chunk_size):
                    x_prime, cond_prime = self.get_emb(
                        sample_t, n_samples, token, audio_conditioning, metadata_conditioning
                    )
                    token = sampled_audio[sample_t]
                    sampled_audio_prime.append(x_prime)
                    conds_prime.append(cond_prime)
                start = start + current_chunk_size
                x_prime, cond_prime = torch.cat(sampled_audio_prime, dim=1), torch.cat(conds_prime, dim=1)
                del sampled_audio_prime
                del conds_prime
                if not get_preds:
                    del cond_prime
                x_prime = self.transformer(x_prime, last_encoder_hidden_states=last_encoder_hidden_states, sample=True)

                if get_preds:
                    if self.add_cond_after_transformer:
                        x_prime = x_prime + cond_prime
                    del cond_prime
                    x_primes.append(x_prime)
                else:
                    del x_prime

            if get_preds:
                x_prime = torch.cat(x_primes, dim=1)
                x_prime = self.fc_proj_out(x_prime)  # Predictions
                preds.append(x_prime)

            # the input of the encoder and decoder can be merged into (lyrics, music tokens)
            input_tokens = sampled_audio[-1]

            itererator = tqdm(
                range(len(sampled_audio), sample_tokens),
                desc=f"Sampling {len(range(len(sampled_audio), sample_tokens))} music tokens",
                leave=False,
            )
            for sample_t in itererator:
                hidden_states, cond = self.get_emb(
                    sample_t, n_samples, input_tokens, audio_conditioning, metadata_conditioning
                )

                hidden_states = self.transformer(
                    hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
                )
                if self.add_cond_after_transformer:
                    hidden_states = hidden_states + cond
                hidden_states = self.fc_proj_out(hidden_states)  # Predictions
                if get_preds:
                    preds.append(hidden_states)
                # Adjust logits
                hidden_states = hidden_states / temp
                hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
                # only music tokens are sampled
                music_tokens = torch.distributions.Categorical(logits=hidden_states).sample()
                sampled_audio.append(music_tokens.clone())
                input_tokens = music_tokens

            del input_tokens, music_tokens
            self.transformer.del_cache()

            music_tokens = torch.cat(sampled_audio, dim=1)
            if get_preds:
                preds = torch.cat(preds, dim=1)
        if get_preds:
            return music_tokens, preds
        else:
            return music_tokens
+
+
class JukeboxMusicTokenConditioner(nn.Module):
    """
    Embeds upper-level music tokens (codes from the VQVAE codebook) and upsamples them with a single
    VQVAE-style decoder convolution block so they can condition the level below.
    """

    def __init__(self, config, level):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size)
        config.embed_dim = config.music_vocab_size  # setting correct argument for the `JukeboxDecoder`

        self.upsampler = JukeboxDecoderConvBock(
            config,
            config.hidden_size,
            config.res_conv_width,
            config.res_conv_depth,
            config.res_downs_t[level],
            config.res_strides_t[level],
            reverse_dilation=False,
        )
        self.layer_norm = JukeboxLayerNorm(config.hidden_size)

    def forward(self, music_tokens, raw_audio_conditionning=None):
        """
        Args:
            music_tokens (`torch.LongTensor`):
                Music tokens from the upper level, in `range(nb_discrete_codes)`.
            raw_audio_conditionning (`torch.LongTensor`, *optional*):
                Raw audio information used when primed sampling; conditions the generation.
        """
        conditioning = 0.0 if raw_audio_conditionning is None else raw_audio_conditionning
        embeddings = self.embed_tokens(music_tokens.long()) + conditioning

        # The upsampler is a convolutional stack and expects channels-first input.
        embeddings = self.upsampler(embeddings.permute(0, 2, 1)).permute(0, 2, 1)
        return self.layer_norm(embeddings)
+
+
class JukeboxRangeEmbedding(nn.Module):
    """
    The `JukeboxRangeEmbedding` interpolate the given [pos_start, pos_end] to obtain an equivalent of time positional
    embedding of length `n_ctx`.

    Binning process : For each pos in position tensor, find its bin [start,end) mapped to [0,1,...,bins-1] [start,end)
    -> [0,1) -> [0, bins) -> floor -> [0,...,bins-1] NOTE: Open ended interval on right, so start <= pos < end, not <=
    end
    """

    def __init__(self, n_time, embed_dim, range, out_width, clamp=False):
        super().__init__()
        self.n_time = n_time
        self.embed_dim = embed_dim
        self.emb = nn.Embedding(embed_dim, out_width)
        # `range` (shadowing the builtin is kept for interface compatibility) gives
        # the half-open interval [pos_min, pos_max) of valid positions.
        self.pos_min, self.pos_max = range
        self.clamp = clamp

    def forward(self, pos_start, pos_end=None):
        """
        Args:
            pos_start (`torch.Tensor` of shape `(batch_size, 1)`):
                Start positions, required to lie in `[pos_min, pos_max)`.
            pos_end (`torch.Tensor`, *optional*):
                End positions; clamped to the valid range when `self.clamp` is set.

        Raises:
            TypeError: if `pos_start` is not 2-dimensional or falls outside `[pos_min, pos_max)`.
        """
        # Check if [pos_start,pos_end] in [pos_min, pos_max)
        if not len(pos_start.shape) == 2:
            raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}")
        # Bug fix: the original `not a.all() and b.all()` bound `not` to the first term
        # only, so values >= pos_max were never rejected.
        if not ((self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all()):
            raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}")

        pos_start = pos_start.float()
        if pos_end is not None:
            if self.clamp:
                pos_end = pos_end.clamp(self.pos_min, self.pos_max)

            pos_end = pos_end.float()
        # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx
        n_time = self.n_time
        if n_time != 1:
            interpolation = (
                torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time
            )
            position = pos_start + (pos_end - pos_start) * interpolation
        else:
            position = pos_start

        # Bin each value to bins_
        # [0,1) -> [0,1..,embed_dim) -> [0,1...,embed_dim-1
        normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min)
        bins_ = (self.embed_dim * normalised_position).floor().long().detach()
        return self.emb(bins_)
+
+
class JukeboxLabelConditioner(nn.Module):
    """Embeds metadata (artist, genres and an optional timing signal) into a start token
    and, when requested, a positional conditioning tensor."""

    def __init__(self, config, include_time_signal):
        super().__init__()

        embed_dim = config.hidden_size
        timing_dims = config.timing_dims
        sampling_rate = config.sampling_rate
        nb_genres, nb_artists = config.metadata_dims
        music_tokens_shape = config.n_ctx

        self.max_nb_genres = config.max_nb_genres
        self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim)
        self.artist_emb = nn.Embedding(nb_artists, embed_dim)
        self.include_time_signal = include_time_signal
        if self.include_time_signal:
            # Timing embeddings: total track length, absolute position and relative position of the chunk.
            total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
            absolute_pos_range = (0.0, config.max_duration * sampling_rate)
            relative_pos_range = (0.0, 1.0)
            self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
            self.absolute_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
            )
            self.relative_pos_emb = JukeboxRangeEmbedding(
                music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
            )

    def forward(self, metadata):
        """Split the metadata columns and return `(start_emb, pos_emb)`; `pos_emb` is `None`
        when no time signal is included."""
        total_length = metadata[:, 0:1]
        offset = metadata[:, 1:2]
        length = metadata[:, 2:3]
        artist = metadata[:, 3:4]
        genre = metadata[:, 4:]

        # Start embedding of length 1; empty genre slots are denoted by -1 and masked out.
        genre_mask = (genre >= 0).float().unsqueeze(2)
        genre_emb = (self.bow_genre_emb(genre.clamp(0)) * genre_mask).sum(dim=1, keepdim=True)
        start_emb = genre_emb + self.artist_emb(artist)

        if not self.include_time_signal:
            return start_emb, None

        # Pos embedding of length n_ctx.
        total_length = total_length.float()
        start = offset.float()
        end = start + length.float()
        pos_emb = (
            self.total_length_emb(total_length)
            + self.absolute_pos_emb(start, end)
            + self.relative_pos_emb(start / total_length, end / total_length)
        )
        return start_emb, pos_emb
+
+
+class JukeboxPrior(PreTrainedModel):
+ """
+ The JukeboxPrior class, which is a wrapper around the various conditioning and the transformer. JukeboxPrior can be
    seen as language models trained on music. They model the next `music token` prediction task. If a (lyric) `encoder`
    is defined, it also models the `next character` prediction on the lyrics. Can be conditioned on timing, artist,
    genre, lyrics and codes from lower-level Priors.
+
+ Args:
+ config (`JukeboxPriorConfig`):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ level (`int`, *optional*):
+ Current level of the Prior. Should be in range `[0,nb_priors]`.
+ nb_priors (`int`, *optional*, defaults to 3):
+ Total number of priors.
+ vqvae_encoder (`Callable`, *optional*):
+ Encoding method of the VQVAE encoder used in the forward pass of the model. Passing functions instead of
+ the vqvae module to avoid getting the parameters.
+ vqvae_decoder (`Callable`, *optional*):
+ Decoding method of the VQVAE decoder used in the forward pass of the model. Passing functions instead of
+ the vqvae module to avoid getting the parameters.
+ """
+
+ config_class = JukeboxPriorConfig
+
+ def _init_weights(self, module):
+ init_scale = self.config.init_scale
+
+ if isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
+ elif isinstance(module, JukeboxConv1D):
+ if self.config.zero_out:
+ module.weight.data.zero_()
+ else:
+ module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
+ elif isinstance(module, JukeboxPositionalEmbedding):
+ module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale)
+ elif isinstance(module, JukeboxRangeEmbedding):
+ module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale)
+ elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"):
+ module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
+ elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"):
+ module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale)
+ elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
+ module.conv1d_2.weigth.data.zero_()
+ module.conv1d_2.bias.data.zero_()
+ if isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
    def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None):
        super().__init__(config)
        # Passing functions instead of the vqvae module to avoid getting params, only used in the
        # forward loop
        self.vqvae_encoder = vqvae_encoder
        self.vqvae_decoder = vqvae_decoder

        self.levels = nb_priors
        self.level = level if level is not None else config.level

        self.base_model_prefix = f"priors.{self.level}"

        self.n_ctx = config.n_ctx

        self.lyric_conditioning = config.nb_relevant_lyric_tokens > 0
        self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens
        self.encoder_loss_fraction = config.encoder_loss_fraction

        # Audio conditioning : conditioning on music tokens (either from audio or from previous levels or both)
        self.audio_conditioning = self.level != 0
        self.cond_level = self.level - 1
        if self.audio_conditioning:
            self.conditioner_blocks = JukeboxMusicTokenConditioner(config, self.level)

        # metadata conditioning : conditioning on timing, genres, and artist
        self.metadata_conditioning = config.metadata_conditioning
        if self.metadata_conditioning:
            # The time signal is only embedded when there is no audio conditioning to carry it.
            self.metadata_embedding = JukeboxLabelConditioner(config, include_time_signal=not self.audio_conditioning)

        # define encoder-decoder or encoder and decoder
        self.is_encoder_decoder = config.is_encoder_decoder
        if config.is_encoder_decoder:
            # encoder-decoder transformer: lyrics and music tokens share one sequence and one
            # merged vocabulary (lyric ids first, music ids shifted by `lyric_vocab_size`).
            self.input_shapes = [config.nb_relevant_lyric_tokens, config.n_ctx]
            self.embed_dim_shift = [0, config.lyric_vocab_size]
            self.width = config.hidden_size

            self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens

            self.prior = JukeboxConditionalAutoregressive(
                config,
                n_ctx=config.nb_relevant_lyric_tokens + config.n_ctx,
                embed_dim=config.lyric_vocab_size + config.music_vocab_size,
                audio_conditioning=(self.audio_conditioning or self.metadata_conditioning),
                metadata_conditioning=True,
            )

        else:
            # Separate encoder-decoder transformer
            encoder_config = config.encoder_config

            if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
                # Dedicated lyric encoder whose hidden states condition the decoder via cross-attention.
                self.lyric_acts_width = encoder_config.hidden_size
                self.encoder_width = config.hidden_size
                self.encoder_dim = config.lyric_vocab_size
                self.encoder = JukeboxConditionalAutoregressive(
                    encoder_config,
                    n_ctx=self.nb_relevant_lyric_tokens,
                    embed_dim=self.encoder_dim,
                    audio_conditioning=False,
                    metadata_conditioning=False,
                    is_encoder=True,
                )
                self.encoder.proj_in = JukeboxConv1D(encoder_config.hidden_size, config.hidden_size)
                self.encoder.final_layer_norm = JukeboxLayerNorm(config.hidden_size)
                self.encoder.lm_head = nn.Linear(config.hidden_size, config.lyric_vocab_size, bias=False)
            else:
                self.nb_relevant_lyric_tokens = 0

            # decoder model on the tokens
            self.prior = JukeboxConditionalAutoregressive(
                config,
                audio_conditioning=(self.audio_conditioning or self.metadata_conditioning),
                metadata_conditioning=self.metadata_conditioning,
            )

        self.next_token_prediction_loss_dims = config.n_ctx
        self.total_loss_dims = self.nb_relevant_lyric_tokens + self.next_token_prediction_loss_dims

        # Overall raw-audio-samples-per-token factor for this level (product of the strides above it).
        self.downsamples = [stride**down for stride, down in zip(config.res_strides_t, config.res_downs_t)]
        self.cond_downsample = self.downsamples[self.level] if self.level != 0 else None
        self.raw_to_tokens = np.prod(self.downsamples[: nb_priors - self.level])
        self.sample_length = self.n_ctx * self.raw_to_tokens

        logger.info(
            f"Level:{self.level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample"
            f" length:{self.sample_length}"
        )
+
+ def get_metadata(self, labels, start, total_length, offset, get_indices=False):
+ metadata = labels.clone()
+ metadata[:, 0] = total_length
+ # Set sample_length to match this level
+ metadata[:, 2] = int(self.sample_length)
+
+ # Set offset
+ metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens)
+ # here since metadata has the full token_list, we just need to selected the ones that are relevant
+
+ # Set lyric tokens
+ metadata, indices = self.set_metadata_lyric_tokens(metadata)
+ if get_indices:
+ return metadata, indices
+ else:
+ return metadata
+
    def set_metadata_lyric_tokens(self, labels):
        """
        Processes the full labels to only retrieve the relevant lyric tokens and keep the metadata conditioning tokens.

        Returns a tuple `(metadata, indices_list)` where `metadata` keeps the first
        `4 + max_nb_genres` conditioning columns followed by the selected lyric tokens, and
        `indices_list` holds, per row, the original positions of the selected tokens.
        When no lyric tokens are configured, returns `(labels, None)` unchanged.
        """
        if self.nb_relevant_lyric_tokens > 0:
            tokens_list = torch.zeros(
                (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device
            )
            indices_list = []  # whats the index of each current character in original array
            for idx in range(labels.shape[0]):
                # NOTE(review): `full_tokens` is sliced over ALL rows each iteration while
                # the timing values come from row `idx` only — presumably
                # `get_relevant_lyric_tokens` handles the batch dimension; verify.
                full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :]
                total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2]
                tokens, indices = get_relevant_lyric_tokens(
                    full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration
                )
                tokens_list[idx, :] = tokens
                indices_list.append(indices)

            return (
                torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1),
                indices_list,
            )
        else:
            return labels, None
+
    def get_music_tokens_conds(self, music_tokens, start, end):
        """
        Extracts current level's conditioning music tokens.

        Returns a one-element list with the upper level's tokens (zero-padded on the right up to
        `n_ctx // cond_downsample`), or `None` for the top level which has no audio conditioning.
        """
        if self.level != 0:
            music_tokens_cond = music_tokens[self.level - 1]
            # NOTE(review): this slice rebinds the local `music_tokens` but the result is not
            # used below — the padding operates on the full `music_tokens_cond`; verify intent.
            music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample]
            missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1]
            if missing_cond_len > 0:
                init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device)
                music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long()
            music_tokens_conds = [music_tokens_cond]
        else:
            music_tokens_conds = None
        return music_tokens_conds
+
+ def prior_preprocess(self, tokens, conds):
+ """
+ Shifts the input tokens to account for the dictionary merge. The embed_dim_shift give by how much the music
+ tokens should be shifted by. It is equal to `lyric_vocab_size`.
+ """
+ batch_size = tokens[0].shape[0]
+ for i in range(len(tokens)):
+ tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1)
+
+ for i in range(len(conds)):
+ if conds[i] is None:
+ conds[i] = torch.zeros(
+ (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device
+ )
+
+ return torch.cat(tokens, dim=1), torch.cat(conds, dim=1)
+
+    def prior_postprocess(self, tokens):
+        """
+        Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is
+        shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music
+        tokens.
+        """
+        batch_size = tokens.shape[0]
+        # Split the fused sequence back into (lyric, music) segments.
+        dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0])
+        tokens = list(torch.split(tokens, dims, dim=1))
+
+        # Some of the input tokens might be shifted to take into account the vocabulary fusion
+        for i in range(len(tokens)):
+            bins_shift = int(self.embed_dim_shift[i])
+            tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1)
+            tokens[i] = torch.clamp(tokens[i], min=0)
+            # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift
+        return tokens[-1]
+
+    def embed_tokens(self, music_tokens_conds):
+        """
+        Embeds the upper level music tokens and upsamples them to provide as audio conditioning.
+        """
+        music_tokens_conds = music_tokens_conds[: self.cond_level + 1]
+        audio_conditioning = None
+        # NOTE(review): `[self.conditioner_blocks]` is a single-element list, so this zip pairs only
+        # the immediately-above level with the conditioner block — presumably intentional (see the
+        # "only uses immediately above layer" comment in `sample`), but confirm.
+        for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))):
+            audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning)
+        return audio_conditioning
+
+ def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1):
+ """
+ Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states.
+ """
+ if start_level is None:
+ start_level = self.level
+ if end_level is None:
+ end_level = self.levels
+ # Get latents
+ with torch.no_grad():
+ latent_states = self.vqvae_encoder(
+ hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
+ )
+ return latent_states
+
+ def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1):
+ """
+ Usamples the sequence of codebook vectors to a raw audio.
+ """
+ if start_level is None:
+ start_level = self.level
+ if end_level is None:
+ end_level = self.levels
+ with torch.no_grad():
+ output = self.vqvae_decoder(
+ music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
+ )
+ return output
+
+ def get_cond(self, music_tokens_conds, metadata):
+ """
+ Converts the input tokens to input_embeddings. Splits the lyrics form the rest of the metadata. Lyric tokens
+ can be None.
+ """
+ if metadata is not None:
+ n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens
+ metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:]
+ else:
+ metadata, lyric_tokens = None, None
+ metadata_conditioning, metadata_pos = (
+ self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None)
+ )
+ audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos
+ return audio_conditioning, metadata_conditioning, lyric_tokens
+
+    def sample(
+        self,
+        n_samples,
+        music_tokens=None,
+        music_tokens_conds=None,
+        metadata=None,
+        temp=1.0,
+        top_k=0,
+        top_p=0.0,
+        chunk_size=None,
+        sample_tokens=None,
+    ):
+        """
+        Ancestral/Prime sampling a window of tokens using the provided conditioning and metadatas.
+
+        Args:
+            n_samples (`int`):
+                Number of samples to generate.
+            music_tokens (`List[torch.LongTensor]`, *optional*):
+                Previously generated tokens at the current level. Used as context for the generation.
+            music_tokens_conds (`List[torch.FloatTensor]`, *optional*):
+                Upper-level music tokens generated by the previous prior model. Is `None` if the generation is not
+                conditioned on the upper-level tokens.
+            metadata (`List[torch.LongTensor]`, *optional*):
+                List containing the metadata tensor with the artist, genre and the lyric tokens.
+            temp (`float`, *optional*, defaults to 1.0):
+                Sampling temperature.
+            top_k (`int`, *optional*, defaults to 0):
+                Top k probabilities used for filtering.
+            top_p (`float`, *optional*, defaults to 0.0):
+                Top p probabilities used for filtering.
+            chunk_size (`int`, *optional*):
+                Size of the chunks used to prepare the cache of the transformer.
+            sample_tokens (`int`, *optional*):
+                Number of tokens to sample.
+
+        """
+        # No context means pure ancestral sampling; otherwise we prime on the provided tokens.
+        no_past_context = music_tokens is None or music_tokens.shape[1] == 0
+        name = {True: "Ancestral", False: "Primed"}[no_past_context]
+        logger.info(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")
+
+        with torch.no_grad():
+            # Currently audio_conditioning only uses immediately above layer
+            audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)
+            if self.is_encoder_decoder:
+                if no_past_context:  # the prime_sample function will be used with music_tokens set to None
+                    lyric_and_music_tokens, audio_conditioning = self.prior_preprocess(
+                        [lyric_tokens], [None, audio_conditioning]
+                    )
+                else:
+                    lyric_and_music_tokens, audio_conditioning = self.prior_preprocess(
+                        [lyric_tokens, music_tokens], [None, audio_conditioning]
+                    )
+                if sample_tokens is not None:
+                    # Account for the lyric prefix that is part of the fused sequence.
+                    sample_tokens += self.nb_relevant_lyric_tokens
+                music_tokens = self.prior.primed_sample(
+                    n_samples,
+                    lyric_and_music_tokens,
+                    audio_conditioning,
+                    metadata_conditioning,
+                    temp=temp,
+                    top_k=top_k,
+                    top_p=top_p,
+                    chunk_size=chunk_size,
+                    sample_tokens=sample_tokens,
+                )
+                # Strip the lyric prefix and undo the vocabulary shift.
+                music_tokens = self.prior_postprocess(music_tokens)
+            else:
+                last_encoder_hidden_states = self.get_encoder_states(lyric_tokens, sample=True)
+                if no_past_context:
+                    music_tokens = self.prior.sample(
+                        n_samples,
+                        audio_conditioning,
+                        metadata_conditioning,
+                        last_encoder_hidden_states,
+                        temp=temp,
+                        top_k=top_k,
+                        top_p=top_p,
+                        sample_tokens=sample_tokens,
+                    )
+                else:
+                    music_tokens = self.prior.primed_sample(
+                        n_samples,
+                        music_tokens,
+                        audio_conditioning,
+                        metadata_conditioning,
+                        last_encoder_hidden_states,
+                        temp=temp,
+                        top_k=top_k,
+                        top_p=top_p,
+                        chunk_size=chunk_size,
+                        sample_tokens=sample_tokens,
+                    )
+        return music_tokens
+
+    def get_encoder_states(self, lyric_tokens, sample=False):
+        """
+        Retrieve the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through
+        the lyric encoder.
+        """
+        if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
+            if sample:
+                # During sampling the encoder may still live on another device; move it first.
+                self.encoder = self.encoder.to(lyric_tokens.device)
+            lyric_acts = self.encoder(lyric_tokens, None, None, None)
+            lyric_acts = self.encoder.proj_in(lyric_acts)
+            last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts)
+        else:
+            # No lyric conditioning: callers must handle the `None` sentinel.
+            last_encoder_hidden_states = None
+        return last_encoder_hidden_states
+
+ def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics):
+ """
+ Computes the loss for the lyric encoder: next lyric token prediction.
+ """
+ if self.lyric_conditioning:
+ last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states)
+ encoder_loss = nn.functional.cross_entropy(
+ last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1)
+ ) / np.log(2.0)
+ else:
+ encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device)
+ return encoder_loss
+
+ def forward_tokens(
+ self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False
+ ):
+ """
+ Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the
+ vqvae's encoding layers.
+ """
+ if get_attn_weights:
+ self.prior.transformer.set_record_attn(get_attn_weights)
+ audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)
+
+ if self.is_encoder_decoder: # the preprocess returns the full tokens (Lyrics and Music tokens), shifted
+ tokens, audio_conditioning = self.prior_preprocess(
+ [lyric_tokens, music_tokens], [None, audio_conditioning]
+ )
+ (encoder_loss, next_token_prediction_loss), preds = self.prior(
+ tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds
+ )
+ else:
+ last_encoder_hidden_states = self.get_encoder_states(lyric_tokens)
+ encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens)
+ next_token_prediction_loss, preds = self.prior(
+ music_tokens,
+ audio_conditioning,
+ metadata_conditioning,
+ last_encoder_hidden_states,
+ get_preds=get_preds,
+ )
+ loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims
+ loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims
+
+ metrics = {
+ "bpd": next_token_prediction_loss.detach().clone(),
+ "encoder_loss": encoder_loss.detach().clone(),
+ "next_token_prediction_loss": next_token_prediction_loss.detach().clone(),
+ }
+ if get_preds:
+ metrics["preds"] = preds.detach().clone()
+ if get_attn_weights:
+ saved_attn_weights = self.prior.transformer.saved_attn_weights
+ self.prior.transformer.set_record_attn(False)
+ return saved_attn_weights
+ else:
+ return loss, metrics
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        metadata: Optional[List[torch.LongTensor]],
+        decode: Optional[bool] = False,
+        get_preds: Optional[bool] = False,
+    ) -> List[torch.Tensor]:
+        """
+        Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens`
+        function. The loss is the sum of the `encoder` loss and the `decoder` loss.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                Hidden states which should be raw audio
+            metadata (`List[torch.LongTensor]`, *optional*):
+                List containing the metadata conditioning tensor with the lyric and the metadata tokens.
+            decode (`bool`, *optional*, defaults to `False`):
+                Whether or not to decode the encoded tokens back to raw audio.
+            get_preds (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the actual predictions of the model.
+        """
+        batch_size = hidden_states.shape[0]
+        # First element is this level's tokens; the rest condition it from the upper levels.
+        music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
+        loss, metrics = self.forward_tokens(
+            music_tokens=music_tokens,
+            music_tokens_conds=music_tokens_conds,
+            metadata=metadata,
+            get_preds=get_preds,
+        )
+        if decode:
+            dequantised_states = self.decode([music_tokens, *music_tokens_conds])
+        else:
+            dequantised_states = None
+        return dequantised_states, loss, metrics
+
+
+class JukeboxPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = JukeboxConfig
+ base_model_prefix = "jukebox"
+ supports_gradient_checkpointing = False
+
+ def _init_weights(self, module):
+ if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE):
+ module.apply(module._init_weights)
+
+ def __init__(self, *inputs, **kwargs):
+ super().__init__(*inputs, **kwargs)
+
+
+JUKEBOX_SAMPLING_INPUT_DOCSTRING = r"""
+ labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` :
+ List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to
+ condition the generation.
+ sampling_kwargs (`Dict[Any]`):
+ Various additional sampling arguments that are used by the `_sample` function. A detail list of the
+ arguments can bee seen in the [`_sample`] function documentation.
+"""
+
+
+@add_start_docstrings(
+ """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`,
+ `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If
+ you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior
+ individually.
+ """,
+ JUKEBOX_START_DOCSTRING,
+)
+class JukeboxModel(JukeboxPreTrainedModel):
+ _no_split_modules = ["JukeboxBlock"]
+
+    def __init__(self, config):
+        """Builds the VQ-VAE and one `JukeboxPrior` per level from the composite `config`."""
+        super().__init__(config)
+        vqvae_config = config.vqvae_config
+        self.vqvae = JukeboxVQVAE(vqvae_config)
+        # Propagate the shared sampling/conditioning parameters before instantiating the priors.
+        self.set_shared_params(config)
+        self.priors = nn.ModuleList(
+            [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
+        )
+
+    def set_shared_params(self, model_config):
+        """
+        Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
+        is nested, and is thus unreachable in the `from_dict` function
+        """
+        # Copy the model-level values onto every per-prior config so each prior sees the same setup.
+        for config in model_config.prior_configs:
+            config.sampling_rate = model_config.sampling_rate
+            config.timing_dims = model_config.timing_dims
+            config.min_duration = model_config.min_duration
+            config.max_duration = model_config.max_duration
+            config.max_nb_genres = model_config.max_nb_genres
+            config.metadata_conditioning = model_config.metadata_conditioning
+
+    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
+        """Decodes music tokens back to raw audio via the VQ-VAE decoder."""
+        return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
+
+    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
+        """Encodes raw audio into per-level music tokens via the VQ-VAE encoder."""
+        return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)
+
+ def split_batch(self, obj, n_samples, split_size):
+ n_passes = (n_samples + split_size - 1) // split_size
+ if isinstance(obj, torch.Tensor):
+ return torch.split(obj, split_size, dim=0)
+ elif isinstance(obj, list):
+ return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
+ elif obj is None:
+ return [None] * n_passes
+ else:
+ raise TypeError("Unknown input type")
+
+ # Sample a partial window of length= self.priors[level].n_ctx:
+ iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length)
+ for start in iterator:
+ music_tokens = self.sample_single_window(
+ music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
+ )
+
+ else:
+ music_tokens = self.sample_partial_window(
+ music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
+ )
+ return music_tokens
+
+    @torch.no_grad()
+    def _sample(
+        self,
+        music_tokens,
+        labels,
+        sample_levels,
+        metas=None,
+        chunk_size=32,
+        sampling_temperature=0.98,
+        lower_batch_size=16,
+        max_batch_size=16,
+        sample_length_in_seconds=24,
+        compute_alignments=False,
+        sample_tokens=None,
+        offset=0,
+        save_results=True,
+        sample_length=None,
+    ) -> List[torch.LongTensor]:
+        """
+        Core sampling function used to generate music tokens. Iterates over the provided list of levels, while saving
+        the generated raw audio at each step.
+
+        Args:
+            music_tokens (`List[torch.LongTensor]`):
+                A sequence of music tokens of length `self.levels` which will be used as context to continue the
+                sampling process. Should have `self.levels` tensors, each corresponding to the generation at a certain
+                level.
+            labels (`List[torch.LongTensor]`):
+                List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
+                lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
+                which are used to condition the generation.
+            sample_levels (`List[int]`):
+                List of the desired levels at which the sampling will be done. A level is equivalent to the index of
+                the prior in the list of priors
+            metas (`List[Any]`, *optional*):
+                Metadatas used to generate the `labels`
+            chunk_size (`int`, *optional*, defaults to 32):
+                Size of a chunk of audio, used to fill up the memory in chunks to prevent OOM errors. Bigger chunks
+                means faster memory filling but more consumption.
+            sampling_temperature (`float`, *optional*, defaults to 0.98):
+                Temperature used to adjust the randomness of the sampling.
+            lower_batch_size (`int`, *optional*, defaults to 16):
+                Maximum batch size for the lower level priors
+            max_batch_size (`int`, *optional*, defaults to 16):
+                Maximum batch size for the top level priors
+            sample_length_in_seconds (`int`, *optional*, defaults to 24):
+                Desired length of the generation in seconds
+            compute_alignments (`bool`, *optional*, defaults to `False`):
+                Whether or not to compute the alignment between the lyrics and the audio using the top_prior
+            sample_tokens (`int`, *optional*):
+                Precise number of tokens that should be sampled at each level. This is mostly useful for running dummy
+                experiments
+            offset (`int`, *optional*, defaults to 0):
+                Audio offset used as conditioning, corresponds to the starting sample in the music. If the offset is
+                greater than 0, the lyrics will be shifted to take that into account
+            save_results (`bool`, *optional*, defaults to `True`):
+                Whether or not to save the intermediate results. If `True`, will generate a folder named with the start
+                time.
+            sample_length (`int`, *optional*):
+                Desired length of the generation in samples.
+
+        Returns: torch.Tensor
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
+        >>> import torch
+
+        >>> metas = dict(artist="Zac Brown Band", genres="Country", lyrics="I met a traveller from an antique land")
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
+
+        >>> labels = tokenizer(**metas)["input_ids"]
+        >>> set_seed(0)
+        >>> zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)]
+        >>> zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False)
+        >>> zs[0]
+        tensor([[1853, 1369, 1150, 1869, 1379, 1789, 519, 710, 1306, 1100, 1229, 519,
+        353, 1306, 1379, 1053, 519, 653, 1631, 1467, 1229, 1229, 10, 1647,
+        1254, 1229, 1306, 1528, 1789, 216, 1631, 1434, 653, 475, 1150, 1528,
+        1804, 541, 1804, 1434]])
+        ```
+        """
+
+        top_prior = self.priors[0]
+        if sample_length is not None:
+            total_length = sample_length
+        else:
+            # Round the requested duration down to a whole number of top-level tokens.
+            total_length = (
+                int(sample_length_in_seconds * self.config.sampling_rate) // top_prior.raw_to_tokens
+            ) * top_prior.raw_to_tokens
+
+        if sample_levels is None:
+            sample_levels = range(len(self.priors))
+
+        # total length of the signal, might be bit different from the actual generated length
+        self.total_length = total_length
+        for level in sample_levels:
+            sampling_kwargs = {
+                "temp": 0.99 if level == len(self.priors) - 1 else sampling_temperature,
+                "chunk_size": chunk_size,
+                "sample_tokens": sample_tokens,
+            }
+            # Set correct total_length, hop_length, labels and sampling_kwargs for level
+
+            total_token_to_sample = total_length // self.priors[level].raw_to_tokens
+            hop_length = int(self.config.hop_fraction[level] * self.priors[level].n_ctx)
+            # NOTE(review): `level` (an int) is compared against `sample_levels` (a list), so this
+            # condition is always True and `lower_batch_size` is always used — confirm intended.
+            max_batch_size = lower_batch_size if level != sample_levels else max_batch_size
+            music_tokens = self.sample_level(
+                music_tokens,
+                labels[level],
+                offset,
+                sampling_kwargs,
+                level,
+                total_token_to_sample,
+                hop_length,
+                max_batch_size,
+            )
+
+            if save_results:
+                self.vqvae.to(music_tokens[level].device)
+                # Decode sample
+                with torch.no_grad():
+                    start_level = len(self.priors) - level - 1  # vqvae levels are reversed
+                    raw_audio = self.vqvae.decode(
+                        music_tokens[: level + 1], start_level=start_level, bs_chunks=music_tokens[level].shape[0]
+                    )
+                logdir = f"jukebox/level_{level}"
+                if not os.path.exists(logdir):
+                    os.makedirs(logdir)
+                save_temp_audio(logdir, level, metas=metas, aud=raw_audio.float())
+                if compute_alignments and self.priors[0] is not None and self.priors[0].nb_relevant_lyric_tokens > 0:
+                    with torch.no_grad():
+                        alignments = get_alignment(music_tokens, labels[0], self.priors[0], self.config)
+                    torch.save({"alignments": alignments}, f"{logdir}/lyric_alignments.pt")
+
+        return music_tokens
+
+    @add_start_docstrings(
+        """
+        Generates music tokens based on the provided `labels. Will start at the desired prior level and automatically
+        upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use
+        the VQ-VAE decoder to convert the music tokens to raw audio.
+
+        Args:
+            labels (`List[torch.LongTensor]`) :
+                List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
+                lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
+                which are used to condition the generation.
+            n_samples (`int`, *optional*, default to 1) :
+                Number of samples to be generated in parallel.
+        """,
+    )
+    def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]:
+        """
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
+
+        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
+        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+
+        >>> lyrics = "Hey, are you awake? Can you talk to me?"
+        >>> artist = "Zac Brown Band"
+        >>> genre = "Country"
+        >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
+        >>> set_seed(0)
+        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)
+
+        >>> with torch.no_grad():
+        ...     model.decode(music_tokens)[:, :10].squeeze(-1)
+        tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
+            -0.0818, -0.0697]])
+        ```
+        """
+
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        # Start from zero-length contexts: pure ancestral sampling at every level.
+        music_tokens = [
+            torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors))
+        ]
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    @add_start_docstrings(
+        """Generates a continuation of the previously generated tokens.
+
+        Args:
+            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
+                A sequence of music tokens which will be used as context to continue the sampling process. Should have
+                `self.levels` tensors, each corresponding to the generation at a certain level.
+        """,
+        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
+    )
+    def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
+        """Continue sampling from previously generated tokens at every requested level."""
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    @add_start_docstrings(
+        """Upsamples a sequence of music tokens using the prior at level `level`.
+
+        Args:
+            music_tokens (`List[torch.LongTensor]` of length `self.levels` ) :
+                A sequence of music tokens which will be used as context to continue the sampling process. Should have
+                `self.levels` tensors, each corresponding to the generation at a certain level.
+        """,
+        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
+    )
+    def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
+        """Upsample existing tokens; by default runs every prior except the top one (index len-1)."""
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1)))
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+    @add_start_docstrings(
+        """Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the
+        generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are
+        used: as conditioning for each level, which means that no ancestral sampling is required.
+
+        Args:
+            raw_audio (`List[torch.Tensor]` of length `n_samples` ) :
+                A list of raw audio that will be used as conditioning information for each samples that will be
+                generated.
+        """,
+        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
+    )
+    def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]:
+        """Encode `raw_audio` into tokens at every level, then sample conditioned on them."""
+        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
+        # The VQ-VAE must be on the audio's device and in float precision before encoding.
+        self.vqvae.to(raw_audio.device).float()
+        with torch.no_grad():
+            music_tokens = self.vqvae.encode(
+                raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
+            )
+        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
+        return music_tokens
+
+
+__all__ = ["JukeboxModel", "JukeboxPreTrainedModel", "JukeboxVQVAE", "JukeboxPrior"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py
new file mode 100644
index 0000000000000000000000000000000000000000..e08ab179a807d9af7050e92838ced80c7d0d1742
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py
@@ -0,0 +1,407 @@
+# coding=utf-8
+# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI Jukebox."""
+
+import json
+import os
+import re
+import unicodedata
+from json.encoder import INFINITY
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import regex
+
+from ....tokenization_utils import AddedToken, PreTrainedTokenizer
+from ....tokenization_utils_base import BatchEncoding
+from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging
+from ....utils.generic import _is_jax, _is_numpy
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {
+ "artists_file": "artists.json",
+ "lyrics_file": "lyrics.json",
+ "genres_file": "genres.json",
+}
+
+
+class JukeboxTokenizer(PreTrainedTokenizer):
+ """
+ Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs :
+ - Artists, unique ids are associated to each artist from the provided dictionary.
+ - Genres, unique ids are associated to each genre from the provided dictionary.
+ - Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the
+ vocabulary.
+
+ This tokenizer does not require training. It should be able to process a different number of inputs:
+ as the conditioning of the model can be done on the three different queries. If None is provided, defaults values will be used.:
+
+ Depending on the number of genres on which the model should be conditioned (`n_genres`).
+ ```python
+ >>> from transformers import JukeboxTokenizer
+
+ >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
+ >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"]
+ [tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49,
+ 40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])]
+ ```
+
+ You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you
+ call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.
+
+
+
+ If nothing is provided, the genres and the artist will either be selected randomly or set to None
+
+
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to:
+ this superclass for more information regarding those methods.
+
+ However the code does not allow that and only supports composing from various genres.
+
+ Args:
+ artists_file (`str`):
+ Path to the vocabulary file which contains a mapping between artists and ids. The default file supports
+ both "v2" and "v3"
+ genres_file (`str`):
+ Path to the vocabulary file which contain a mapping between genres and ids.
+ lyrics_file (`str`):
+ Path to the vocabulary file which contains the accepted characters for the lyrics tokenization.
+ version (`List[str]`, `optional`, default to `["v3", "v2", "v2"]`) :
+ List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of
+ `v2`.
+ n_genres (`int`, `optional`, defaults to 1):
+ Maximum number of genres to use for composition.
+ max_n_lyric_tokens (`int`, `optional`, defaults to 512):
+ Maximum number of lyric tokens to keep.
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        artists_file,
+        genres_file,
+        lyrics_file,
+        version=["v3", "v2", "v2"],
+        max_n_lyric_tokens=512,
+        n_genres=5,
+        unk_token="<|endoftext|>",
+        **kwargs,
+    ):
+        """Loads the artist, genre and lyric vocabularies and builds the matching decoders."""
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        self.version = version
+        self.max_n_lyric_tokens = max_n_lyric_tokens
+        self.n_genres = n_genres
+        # Reserve id 0 for the unknown token.
+        self._added_tokens_decoder = {0: unk_token}
+
+        with open(artists_file, encoding="utf-8") as vocab_handle:
+            self.artists_encoder = json.load(vocab_handle)
+
+        with open(genres_file, encoding="utf-8") as vocab_handle:
+            self.genres_encoder = json.load(vocab_handle)
+
+        with open(lyrics_file, encoding="utf-8") as vocab_handle:
+            self.lyrics_encoder = json.load(vocab_handle)
+
+        # Regex matching characters that are OUT of the lyric vocabulary (they will be stripped).
+        oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+"
+        # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters.
+        if len(self.lyrics_encoder) == 79:
+            # v3 vocabulary: also accept the "+" character.
+            oov = oov.replace(r"\-'", r"\-+'")
+
+        self.out_of_vocab = regex.compile(oov)
+        self.artists_decoder = {v: k for k, v in self.artists_encoder.items()}
+        self.genres_decoder = {v: k for k, v in self.genres_encoder.items()}
+        self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()}
+        super().__init__(
+            unk_token=unk_token,
+            n_genres=n_genres,
+            version=version,
+            max_n_lyric_tokens=max_n_lyric_tokens,
+            **kwargs,
+        )
+
+ @property
+ def vocab_size(self):
+ return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder)
+
+ def get_vocab(self):
+ return {
+ "artists_encoder": self.artists_encoder,
+ "genres_encoder": self.genres_encoder,
+ "lyrics_encoder": self.lyrics_encoder,
+ }
+
+    def _convert_token_to_id(self, list_artists, list_genres, list_lyrics):
+        """Converts the artist, genre and lyrics tokens to their index using the vocabulary.
+        The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to
+        the lyrics token sequence.
+        """
+        # Unknown artists map to id 0.
+        artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists]
+        for genres in range(len(list_genres)):
+            # NOTE: mutates `list_genres` in place; unknown genres map to 0 and each list is
+            # right-padded with -1 up to `n_genres` entries.
+            list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]]
+            list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres]))
+
+        # Only the first lyric sequence is tokenized — presumably only the top-level prior consumes
+        # lyrics, the other two levels receive empty lists. TODO(review): confirm against the priors.
+        lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []]
+        return artists_id, list_genres, lyric_ids
+
+ def _tokenize(self, lyrics):
+ """
+ Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
+
+ Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary.
+ """
+ # only lyrics are not tokenized, but character based is easily handled
+ return list(lyrics)
+
+    def tokenize(self, artist, genre, lyrics, **kwargs):
+        """
+        Converts three strings in a 3 sequence of tokens using the tokenizer.
+
+        Normalizes artist/genre first, then splits the filtered lyrics into characters.
+        """
+        artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics)
+        lyrics = self._tokenize(lyrics)
+        return artist, genre, lyrics
+
+    def prepare_for_tokenization(
+        self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False
+    ) -> Tuple[str, str, str, Dict[str, Any]]:
+        """
+        Performs any necessary transformations before tokenization.
+
+        Args:
+            artists (`str`):
+                The artist name to prepare. This will mostly lower the string
+            genres (`str`):
+                The genre name to prepare. This will mostly lower the string.
+            lyrics (`str`):
+                The lyrics to prepare.
+            is_split_into_words (`bool`, *optional*, defaults to `False`):
+                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                which it will tokenize. This is useful for NER or token classification.
+
+        Note:
+            This method has side effects on the tokenizer itself: depending on `self.version[0]` it
+            replaces `self.out_of_vocab` and, for v2, rebuilds `self.lyrics_encoder`/`lyrics_decoder`.
+            The returned `lyrics` is a 3-tuple `(filtered_lyrics, [], [])` so that downstream code
+            receives one entry per prior level.
+        """
+        for idx in range(len(self.version)):
+            if self.version[idx] == "v3":
+                artists[idx] = artists[idx].lower()
+                genres[idx] = [genres[idx].lower()]
+            else:
+                artists[idx] = self._normalize(artists[idx]) + ".v2"
+                genres[idx] = [
+                    self._normalize(genre) + ".v2" for genre in genres[idx].split("_")
+                ]  # split is for the full dictionary with combined genres
+
+        if self.version[0] == "v2":
+            # v2 character set: rebuild the lyric vocabulary in place (id 0 is the empty string).
+            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+")
+            vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n"
+            self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))}
+            self.vocab[""] = 0
+            self.n_vocab = len(vocab) + 1
+            self.lyrics_encoder = self.vocab
+            self.lyrics_decoder = {v: k for k, v in self.vocab.items()}
+            self.lyrics_decoder[0] = ""
+        else:
+            # v3 additionally accepts the "+" character.
+            self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+")
+
+        lyrics = self._run_strip_accents(lyrics)
+        lyrics = lyrics.replace("\\", "\n")
+        lyrics = self.out_of_vocab.sub("", lyrics), [], []
+        return artists, genres, lyrics
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _normalize(self, text: str) -> str:
+ """
+ Normalizes the input text. This process is for the genres and the artist
+
+ Args:
+ text (`str`):
+ Artist or Genre string to normalize
+ """
+
+ accepted = (
+ [chr(i) for i in range(ord("a"), ord("z") + 1)]
+ + [chr(i) for i in range(ord("A"), ord("Z") + 1)]
+ + [chr(i) for i in range(ord("0"), ord("9") + 1)]
+ + ["."]
+ )
+ accepted = frozenset(accepted)
+ pattern = re.compile(r"_+")
+ text = "".join([c if c in accepted else "_" for c in text.lower()])
+ text = pattern.sub("_", text).strip("_")
+ return text
+
+ def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str:
+ return " ".join(lyrics)
+
+ def convert_to_tensors(
+ self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
+ ):
+ """
+ Convert the inner content to tensors.
+
+ Args:
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+ unset, no modification is done.
+ prepend_batch_axis (`int`, *optional*, defaults to `False`):
+ Whether or not to add the batch dimension during the conversion.
+ """
+ # Convert to TensorType
+ if not isinstance(tensor_type, TensorType):
+ tensor_type = TensorType(tensor_type)
+
+ # Get a function reference for the correct framework
+ if tensor_type == TensorType.TENSORFLOW:
+ if not is_tf_available():
+ raise ImportError(
+ "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
+ )
+ import tensorflow as tf
+
+ as_tensor = tf.constant
+ is_tensor = tf.is_tensor
+ elif tensor_type == TensorType.PYTORCH:
+ if not is_torch_available():
+ raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+ import torch
+
+ as_tensor = torch.tensor
+ is_tensor = torch.is_tensor
+ elif tensor_type == TensorType.JAX:
+ if not is_flax_available():
+ raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
+ import jax.numpy as jnp # noqa: F811
+
+ as_tensor = jnp.array
+ is_tensor = _is_jax
+ else:
+ as_tensor = np.asarray
+ is_tensor = _is_numpy
+
+ # Do the tensor conversion in batch
+
+ try:
+ if prepend_batch_axis:
+ inputs = [inputs]
+
+ if not is_tensor(inputs):
+ inputs = as_tensor(inputs)
+ except: # noqa E722
+ raise ValueError(
+ "Unable to create tensor, you should probably activate truncation and/or padding "
+ "with 'padding=True' 'truncation=True' to have batched tensors with the same length."
+ )
+
+ return inputs
+
+ def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding:
+ """Convert the raw string to a list of token ids
+
+ Args:
+ artist (`str`):
+ Name of the artist.
+ genres (`str`):
+ List of genres that will be mixed to condition the audio
+ lyrics (`str`, *optional*, defaults to `""`):
+ Lyrics used to condition the generation
+ """
+ input_ids = [0, 0, 0]
+ artist = [artist] * len(self.version)
+ genres = [genres] * len(self.version)
+
+ artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics)
+ artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens)
+
+ attention_masks = [-INFINITY] * len(full_tokens[-1])
+ input_ids = [
+ self.convert_to_tensors(
+ [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors
+ )
+ for i in range(len(self.version))
+ ]
+ return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks})
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Saves the tokenizer's vocabulary dictionary to the provided save_directory.
+
+ Args:
+ save_directory (`str`):
+ A path to the directory where to saved. It will be created if it doesn't exist.
+
+ filename_prefix (`Optional[str]`, *optional*):
+ A prefix to add to the names of the files saved by the tokenizer.
+
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+
+ artists_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"]
+ )
+ with open(artists_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.artists_encoder, ensure_ascii=False))
+
+ genres_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"]
+ )
+ with open(genres_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.genres_encoder, ensure_ascii=False))
+
+ lyrics_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"]
+ )
+ with open(lyrics_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False))
+
+ return (artists_file, genres_file, lyrics_file)
+
+ def _convert_id_to_token(self, artists_index, genres_index, lyric_index):
+ """
+ Converts an index (integer) in a token (str) using the vocab.
+
+ Args:
+ artists_index (`int`):
+ Index of the artist in its corresponding dictionary.
+ genres_index (`Union[List[int], int]`):
+ Index of the genre in its corresponding dictionary.
+ lyric_index (`List[int]`):
+ List of character indices, which each correspond to a character.
+ """
+ artist = self.artists_decoder.get(artists_index)
+ genres = [self.genres_decoder.get(genre) for genre in genres_index]
+ lyrics = [self.lyrics_decoder.get(character) for character in lyric_index]
+ return artist, genres, lyrics
+
+
+__all__ = ["JukeboxTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53ec5ed37c13614266d04cde838ff7360946a451
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers, expose the submodule symbols eagerly so that
    # references resolve without importing the heavy runtime dependencies.
    from .configuration_mctct import *
    from .feature_extraction_mctct import *
    from .modeling_mctct import *
    from .processing_mctct import *
else:
    import sys

    # At runtime, replace this module object with a lazy proxy so submodules
    # are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cba190a0f460e3fed5a3ebbc773e9ab31283c1a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py
@@ -0,0 +1,184 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""M-CTC-T model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class MCTCTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an
    M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the M-CTC-T
    [speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 8065):
            Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MCTCTModel`].
        hidden_size (`int`, *optional*, defaults to 1536):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 36):
            Number of hidden layers in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        attention_head_dim (`int`, *optional*, defaults to 384):
            Dimensions of each attention head for each attention layer in the Transformer encoder.
        max_position_embeddings (`int`, *optional*, defaults to 920):
            The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction).
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        layerdrop (`float`, *optional*, defaults to 0.3):
            The probability of dropping an encoder layer during training. The default 0.3 value is used in the original
            implementation.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.3):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3):
            The dropout ratio for the attention probabilities.
        pad_token_id (`int`, *optional*, defaults to 1):
            The tokenizer index of the pad token.
        bos_token_id (`int`, *optional*, defaults to 0):
            The tokenizer index of the bos token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The tokenizer index of the eos token.
        conv_glu_dim (`int`, *optional*, defaults to 1):
            The dimension of the output of the `Conv1dSubsampler` layer in which GLU is applied on. Though the original
            Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences.
        conv_dropout (`float`, *optional*, defaults to 0.3):
            The probability of randomly dropping the `Conv1dSubsampler` layer during training.
        num_conv_layers (`int`, *optional*, defaults to 1):
            Number of convolution layers before applying transformer encoder layers.
        conv_kernel (`Sequence[int]`, *optional*, defaults to `(7,)`):
            The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal
            to `num_conv_layers`.
        conv_stride (`Sequence[int]`, *optional*, defaults to `(3,)`):
            The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal
            to `num_conv_layers`.
        input_feat_per_channel (`int`, *optional*, defaults to 80):
            Feature dimensions of the channels of the input to the Conv1D layer.
        input_channels (`int`, *optional*, defaults to 1):
            Number of input channels of the input to the Conv1D layer.
        conv_channels (`List[int]`, *optional*):
            Channel sizes of intermediate Conv1D layers.
        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
            instance of [`MCTCTForCTC`].
        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
            of [`MCTCTForCTC`].

    Raises:
        ValueError: if `len(conv_kernel)` or `len(conv_stride)` does not match `num_conv_layers`.

    Example:

    ```python
    >>> from transformers import MCTCTConfig, MCTCTModel

    >>> # Initializing a M-CTC-T mctct-large style configuration
    >>> configuration = MCTCTConfig()

    >>> # Initializing a model (with random weights) from the mctct-large style configuration
    >>> model = MCTCTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mctct"

    def __init__(
        self,
        vocab_size=8065,
        hidden_size=1536,
        num_hidden_layers=36,
        intermediate_size=6144,
        num_attention_heads=4,
        attention_head_dim=384,
        max_position_embeddings=920,
        layer_norm_eps=1e-5,
        layerdrop=0.3,
        hidden_act="relu",
        initializer_range=0.02,
        hidden_dropout_prob=0.3,
        attention_probs_dropout_prob=0.3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        conv_glu_dim=1,
        conv_dropout=0.3,
        num_conv_layers=1,
        conv_kernel=(7,),
        conv_stride=(3,),
        input_feat_per_channel=80,
        input_channels=1,
        conv_channels=None,
        ctc_loss_reduction="sum",
        ctc_zero_infinity=False,
        **kwargs,
    ):
        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.layerdrop = layerdrop
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.conv_glu_dim = conv_glu_dim
        self.conv_dropout = conv_dropout
        self.num_conv_layers = num_conv_layers
        self.input_feat_per_channel = input_feat_per_channel
        self.input_channels = input_channels
        self.conv_channels = conv_channels
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        # prevents config testing fail with exporting to json
        self.conv_kernel = list(conv_kernel)
        self.conv_stride = list(conv_stride)

        if len(self.conv_kernel) != self.num_conv_layers:
            raise ValueError(
                "Configuration for convolutional module is incorrect. "
                "It is required that `len(config.conv_kernel)` == `config.num_conv_layers` "
                f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, "
                f"`config.num_conv_layers = {self.num_conv_layers}`."
            )

        # The docstring also requires `len(conv_stride) == num_conv_layers`;
        # validate it here so a mismatch fails at construction time instead of
        # when the convolution stack is built.
        if len(self.conv_stride) != self.num_conv_layers:
            raise ValueError(
                "Configuration for convolutional module is incorrect. "
                "It is required that `len(config.conv_stride)` == `config.num_conv_layers` "
                f"but is `len(config.conv_stride) = {len(self.conv_stride)}`, "
                f"`config.num_conv_layers = {self.num_conv_layers}`."
            )
+
+
+__all__ = ["MCTCTConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py
new file mode 100644
index 0000000000000000000000000000000000000000..f210031f3097e084c288fa58773d4271301e7c29
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Feature extractor class for M-CTC-T
+"""
+
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ....audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
+from ....feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ....feature_extraction_utils import BatchFeature
+from ....file_utils import PaddingStrategy, TensorType
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class MCTCTFeatureExtractor(SequenceFeatureExtractor):
    r"""
    Constructs a M-CTC-T feature extractor.

    This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
    most of the main methods. Users should refer to this superclass for more information regarding those methods. This
    code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to
    this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an)
    that takes the user step-by-step in the implementation.

    Args:
        feature_size (`int`, defaults to 80):
            The feature dimension of the extracted features. This is the number of mel-frequency filter banks.
        sampling_rate (`int`, defaults to 16000):
            The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
        padding_value (`float`, defaults to 0.0):
            The value that is used to fill the padding values.
        hop_length (`int`, defaults to 10):
            Number of audio samples between windows. Otherwise referred to as "shift" in many papers.
        win_length (`int`, defaults to 25):
            Number of ms per window
        win_function (`str`, defaults to `"hamming_window"`):
            Name for the window function used for windowing, must be accessible via `torch.{win_function}`
        frame_signal_scale (`float`, defaults to 32768.0):
            Constant multiplied in creating the frames before applying DFT.
        preemphasis_coeff (`float`, defaults to 0.97):
            Constant multiplied in applying Pre-emphasis before DFT.
        mel_floor (`float` defaults to 1.0):
            Minimum value of mel frequency banks.
        normalize_means (`bool`, *optional*, defaults to `True`):
            Whether or not to zero-mean normalize the extracted features.
        normalize_vars (`bool`, *optional*, defaults to `True`):
            Whether or not to unit-variance normalize the extracted features.
    """

    model_input_names = ["input_features", "attention_mask"]

    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        padding_value=0.0,
        hop_length=10,
        win_length=25,
        win_function="hamming_window",
        frame_signal_scale=32768.0,
        preemphasis_coeff=0.97,
        mel_floor=1.0,
        normalize_means=True,
        normalize_vars=True,
        return_attention_mask=False,
        **kwargs,
    ):
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.hop_length = hop_length
        self.win_length = win_length
        self.frame_signal_scale = frame_signal_scale
        self.preemphasis_coeff = preemphasis_coeff
        self.mel_floor = mel_floor
        self.normalize_means = normalize_means
        self.normalize_vars = normalize_vars
        self.win_function = win_function
        self.return_attention_mask = return_attention_mask

        # win_length / hop_length are given in milliseconds; convert to samples.
        self.sample_size = win_length * sampling_rate // 1000
        self.sample_stride = hop_length * sampling_rate // 1000

        # FFT length is padded up to an efficient size for the window length.
        self.n_fft = optimal_fft_length(self.sample_size)
        self.n_freqs = (self.n_fft // 2) + 1

    def _extract_mfsc_features(self, one_waveform: np.ndarray) -> np.ndarray:
        """
        Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code.
        """
        # NOTE(review): hamming uses periodic=False, presumably for parity with
        # the original Flashlight implementation — confirm before changing.
        if self.win_function == "hamming_window":
            window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False)
        else:
            window = window_function(window_length=self.sample_size, name=self.win_function)

        fbanks = mel_filter_bank(
            num_frequency_bins=self.n_freqs,
            num_mel_filters=self.feature_size,
            min_frequency=0.0,
            max_frequency=self.sampling_rate / 2.0,
            sampling_rate=self.sampling_rate,
        )

        msfc_features = spectrogram(
            one_waveform * self.frame_signal_scale,
            window=window,
            frame_length=self.sample_size,
            hop_length=self.sample_stride,
            fft_length=self.n_fft,
            center=False,
            preemphasis=self.preemphasis_coeff,
            mel_filters=fbanks,
            mel_floor=self.mel_floor,
            log_mel="log",
        )
        # Transpose so the result is (num_frames, feature_size).
        return msfc_features.T

    def _normalize_one(self, x, input_length, padding_value):
        # make sure we normalize float32 arrays
        if self.normalize_means:
            mean = x[:input_length].mean(axis=0)
            x = np.subtract(x, mean)
        if self.normalize_vars:
            # NOTE(review): a constant feature column yields std == 0 and a
            # divide-by-zero warning / inf here — confirm acceptable.
            std = x[:input_length].std(axis=0)
            x = np.divide(x, std)

        # Re-apply the padding value beyond the real length, since the
        # mean/variance shifts above also moved the padded region.
        if input_length < x.shape[0]:
            x[input_length:] = padding_value

        # make sure array is in float32
        x = x.astype(np.float32)

        return x

    def normalize(
        self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None
    ) -> List[np.ndarray]:
        # Without an attention mask, every frame of each feature array is
        # treated as real (unpadded) data.
        lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
        return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)]

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to featurize and prepare for the model one or several sequence(s). It returns the
        log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code.

        Args:
            raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
                The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
                of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
                mono channel audio, not stereo, i.e. single float per timestep.
            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`):
                Activates truncation to cut input sequences longer than *max_length* to *max_length*.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific feature_extractor's default.

                [What are attention masks?](../glossary#attention-mask)

            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            sampling_rate (`int`, *optional*):
                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                `sampling_rate` at the forward call to prevent silent errors.
            padding_value (`float`, defaults to 0.0):
                The value used to pad the returned sequences.
        """

        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the ``sampling_rate`` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        # Determine whether the input is a batch (2D array or list of sequences);
        # anything with more than one channel is rejected.
        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
        if is_batched_numpy and len(raw_speech.shape) > 2:
            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
        is_batched = is_batched_numpy or (
            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
        )

        if is_batched:
            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]

        # extract fbank features
        features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech]

        # convert into correct format for padding
        encoded_inputs = BatchFeature({"input_features": features})

        # Always request the attention mask here: normalization below needs the
        # real lengths even if the caller did not ask for the mask.
        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=True,
            **kwargs,
        )
        # make sure list is in array format
        input_features = padded_inputs.get("input_features")
        if isinstance(input_features[0], list):
            padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

        attention_mask = padded_inputs.get("attention_mask")
        if attention_mask is not None:
            padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]

        if self.normalize_means or self.normalize_vars:
            # Only use the mask when padding was actually applied; otherwise the
            # full length of each array is the real length.
            attention_mask = (
                np.array(attention_mask, dtype=np.int32)
                if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
                and padding
                else None
            )
            padded_inputs["input_features"] = self.normalize(
                padded_inputs["input_features"], attention_mask=attention_mask
            )

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
+
+
+__all__ = ["MCTCTFeatureExtractor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dd074b28c53d6d501ec680013afc517118bc227
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py
@@ -0,0 +1,791 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch M-CTC-T model."""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ....activations import ACT2FN
+from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ....integrations.deepspeed import is_deepspeed_zero3_enabled
+from ....integrations.fsdp import is_fsdp_managed_module
+from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ....modeling_outputs import BaseModelOutput, CausalLMOutput
+from ....modeling_utils import (
+ PreTrainedModel,
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ prune_linear_layer,
+)
+from ....utils import logging
+from .configuration_mctct import MCTCTConfig
+
+
+logger = logging.get_logger(__name__)
+
+_HIDDEN_STATES_START_POSITION = 1
+
+_CONFIG_FOR_DOC = "MCTCTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large"
+_EXPECTED_OUTPUT_SHAPE = [1, 195, 1536]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."'
+_CTC_EXPECTED_LOSS = 1885.65
+
+
class MCTCTConv1dSubsampler(nn.Module):
    """
    Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
    via gated linear units (https://arxiv.org/abs/1911.08460)
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.glu_dim = config.conv_glu_dim

        self.dropout = nn.Dropout(config.conv_dropout)

        self.num_layers = config.num_conv_layers
        self.in_channels = config.input_feat_per_channel * config.input_channels

        if self.num_layers > 1:
            if config.conv_channels is None:
                raise ValueError(
                    "Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution"
                    " layers."
                )
            self.mid_channels = config.conv_channels
        else:
            self.mid_channels = None

        # Output channels are doubled so that the GLU halves them back to hidden_size.
        self.out_channels = config.hidden_size * 2
        self.kernel_size = config.conv_kernel
        self.stride = config.conv_stride

        # NOTE: MCTCT by construction only uses one convolution kernel. This stays flexible for
        # multiple convolution layers, but the padding in forward() assumes a single layer.
        conv_layers = []
        for layer_idx, kernel in enumerate(self.kernel_size):
            source_dim = self.in_channels if layer_idx == 0 else self.mid_channels[layer_idx]
            target_dim = self.mid_channels[layer_idx] if layer_idx < self.num_layers - 1 else self.out_channels
            conv_layers.append(
                nn.Conv1d(
                    source_dim,
                    target_dim,
                    kernel_size=kernel,
                    stride=self.stride[layer_idx],
                    padding="valid",
                )
            )
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, input_features):
        # Padding is computed as if there were exactly one conv layer (see NOTE in __init__).
        padding = sum(size // 2 for size in self.kernel_size)  # (7, 7) -> (3, 3)
        padded = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0)

        # Batch x Time x Feature -> Batch x Feature x Time, as required by Conv1d.
        hidden_states = padded.transpose(1, 2).contiguous()
        for conv in self.conv_layers:
            hidden_states = self.dropout(nn.functional.glu(conv(hidden_states), dim=self.glu_dim))

        # Back to Batch x Time x Feature.
        return hidden_states.transpose(1, 2).contiguous()
+
+
class MCTCTEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # Capitalized "LayerNorm" is kept so TensorFlow-era checkpoint variable names still map onto it.
        self.LayerNorm = MCTCTLayerNorm()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Non-persistent helper buffers: a full row of position ids and an all-zero
        # token-type template (lets traced models run without explicit token_type_ids).
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
            persistent=False,
        )

    def forward(
        self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_features is not None:
            input_shape = input_features.size()
        else:
            input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Fall back to the registered zero buffer when token_type_ids are not supplied
        # (helps users tracing the model; see issue #5664).
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_features)

        # NOTE(review): position_ids are sliced above but position embeddings are never added
        # to the sum — this mirrors the original implementation.
        embeddings = self.LayerNorm(inputs_embeds + self.token_type_embeddings(token_type_ids))
        return self.dropout(embeddings)
+
+
class MCTCTSelfAttention(nn.Module):
    """Multi-head self-attention with learned relative key-position embeddings.

    The query projection is pre-scaled by 1/sqrt(attention_head_size), and relative-position
    scores computed against `distance_embedding` are rotated into alignment with the content
    scores (via Fortran-order reshapes) before the softmax.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        # Head size comes from a dedicated config field, not hidden_size // num_heads.
        self.attention_head_size = config.attention_head_dim
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # All three projections are bias-free.
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.max_position_embeddings = config.max_position_embeddings
        # One embedding per possible relative distance in [-(L-1), L-1].
        self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        """Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def reshape_fortran(self, x, shape):
        """Reshape `x` to `shape` in Fortran (column-major) order, like numpy's order="F"."""
        if len(x.shape) > 0:
            x = x.permute(*reversed(range(len(x.shape))))
        return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))

    def relative_position_embedding_rotate(self, scores):
        """Rotate the relative-position score tensor so each row lines up with absolute positions.

        Externally the layout is (batch, heads, positions, seq_len); internally the tensor is
        permuted, zero-padded, flattened and truncated in Fortran order, then re-sliced around
        the midpoint of the position axis.
        """
        # NOTE: should re-evaluate whether this re-implementation was truly necessary
        # or the reason why my complete re-haul worked was due to some other part
        # of the code. Adding this and the reshape fortran code seems very undesirable.
        scores = scores.permute(0, 2, 3, 1)  # e.g. [10, 1839, 14, 4]

        batch, hidden_state, seq_len, heads = scores.shape

        # Pad with zeros along the position axis. e.g. [10, 1853, 14, 4]
        scores = torch.cat((scores, torch.zeros((batch, seq_len, seq_len, heads), device=scores.device)), dim=1)

        # Flatten positions x sequence in column-major order. e.g. [10, 25942, 1, 4]
        scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads])

        # Drop the trailing pad entries. e.g. [10, 25928, 1, 4]
        scores = scores[:, : (seq_len + hidden_state - 1) * seq_len]

        # Unflatten back to (positions, seq_len). e.g. [10, 1852, 14, 4]
        scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads])

        # Keep the seq_len-wide window centered on the position axis.
        halfpoint = hidden_state // 2
        scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2)  # e.g. [10, 14, 14, 4]

        return scores.permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        # Scale queries before splitting into heads (equivalent to scaling the scores).
        mixed_query_layer = self.query(hidden_states)
        mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Relative key position embeddings: contract every query against every distance embedding.
        positional_embedding = self.distance_embedding.weight
        relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3))

        relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores)
        attention_scores = attention_scores + relative_position_scores

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in MCTCTModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
class MCTCTLayerNorm(nn.Module):
    """Scalar affine transform: a single learnable gain and bias shared across all features.

    Despite the name this is not a true LayerNorm — it performs no normalization,
    only elementwise `w * x + b` with scalar parameters.
    """

    def __init__(self):
        super().__init__()
        self.singleton_weight = nn.Parameter(torch.ones(1))
        self.singleton_bias = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_states):
        scaled = hidden_states * self.singleton_weight
        return scaled + self.singleton_bias
+
+
class MCTCTSelfOutput(nn.Module):
    """Attention output block: bias-free projection, dropout, then residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm.
        return self.LayerNorm(projected + input_tensor)
+
+
class MCTCTAttention(nn.Module):
    """Self-attention plus its output projection, with support for pruning attention heads."""

    def __init__(self, config):
        super().__init__()
        self.self = MCTCTSelfAttention(config)
        self.output = MCTCTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from all projections (no-op when `heads` is empty)."""
        if not heads:
            return
        prunable_heads, keep_index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink every projection down to the surviving head dimensions.
        self.self.query = prune_linear_layer(self.self.query, keep_index)
        self.self.key = prune_linear_layer(self.self.key, keep_index)
        self.self.value = prune_linear_layer(self.self.value, keep_index)
        self.output.dense = prune_linear_layer(self.output.dense, keep_index, dim=1)

        # Book-keeping: fewer heads, smaller combined size, remember what was pruned.
        self.self.num_attention_heads -= len(prunable_heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads |= prunable_heads

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        projected = self.output(attn_results[0], hidden_states)
        # Keep attention probabilities (if requested) behind the projected output.
        return (projected,) + attn_results[1:]
+
+
class MCTCTIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size plus the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        # A string activation name is resolved through ACT2FN; callables are used as-is.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class MCTCTOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size, dropout, residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm.
        return self.LayerNorm(contracted + input_tensor)
+
+
class MCTCTLayer(nn.Module):
    """One encoder layer: self-attention followed by a chunked feed-forward block."""

    def __init__(self, config: MCTCTConfig):
        super().__init__()

        self.seq_len_dim = 1
        self.chunk_size_feed_forward = config.chunk_size_feed_forward

        self.intermediate = MCTCTIntermediate(config)
        self.attention = MCTCTAttention(config)
        self.is_decoder = config.is_decoder
        self.output = MCTCTOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        attn_outputs = self.attention(
            hidden_states, attention_mask, head_mask, output_attentions=output_attentions
        )

        # Feed-forward is applied in chunks along the sequence dimension to bound memory use.
        ffn_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_outputs[0]
        )

        # Layer output first, then attention weights if they were requested.
        return (ffn_output,) + attn_outputs[1:]

    def feed_forward_chunk(self, attention_output):
        return self.output(self.intermediate(attention_output), attention_output)
+
+
class MCTCTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MCTCTConfig
    base_model_prefix = "mctct"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, MCTCTLayerNorm):
            module.singleton_weight.data.fill_(1.0)
            module.singleton_bias.data.zero_()
        # NOTE(review): this trailing check also matches nn.Linear (already handled above), so
        # Linear weights are sampled twice and the second draw wins; its practical purpose is
        # to cover nn.Conv1d. Kept as-is to preserve RNG/initialization behavior.
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """
        # Per-layer Conv1d output-length formula; the 2 * (kernel // 2) padding term matches
        # the manual padding applied in MCTCTConv1dSubsampler.forward().
        dilation = 1
        for _, kernel_sz, stride in zip(
            range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride
        ):
            padding = kernel_sz // 2
            input_lengths = input_lengths + 2 * padding - dilation * (kernel_sz - 1) - 1
            input_lengths = torch.div(input_lengths, stride, rounding_mode="trunc") + 1

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask):
        """Build an attention mask of length `feature_vector_length` for the post-subsampling sequence."""
        # generate creates 3D attention mask, because of the shape of input_features
        # convert it to 2D if that's the case
        if len(attention_mask.shape) > 2:
            attention_mask = attention_mask[:, :, -1]

        # subsampled_lengths = attention_mask.sum(-1)
        subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
        bsz = attention_mask.size()[0]
        attention_mask = torch.zeros(
            (bsz, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )

        # these two operations makes sure that all values
        # before the output lengths indices are attended to
        attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
        return attention_mask
+
+
# Docstring templates consumed by the @add_start_docstrings / @add_start_docstrings_to_model_forward
# decorators on the model classes below.
MCTCT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MCTCT_INPUTS_DOCSTRING = r"""
    Args:
        input_features (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
+
+
class MCTCTEncoder(MCTCTPreTrainedModel):
    """Encoder: scalar layer norm, convolutional subsampler, then a stack of MCTCTLayers.

    Supports LayerDrop (randomly skipping whole layers during training) and gradient
    checkpointing.
    """

    def __init__(self, config: MCTCTConfig):
        super().__init__(config)
        self.hidden_dropout_prob = config.hidden_dropout_prob

        self.layer_norm = MCTCTLayerNorm()
        self.conv = MCTCTConv1dSubsampler(config)
        self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor,
        head_mask: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # NOTE(review): with the declared non-None defaults, these config fallbacks only take
        # effect when a caller explicitly passes None — confirm intentional.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_features = self.layer_norm(input_features)

        # Subsample along time; the sequence length shrinks according to the conv strides.
        inputs_embeds = self.conv(input_features)

        # subsample attention mask if necessary
        if attention_mask is not None:
            attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)

        hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training)

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != len(self.layers):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, "
                    f"but it is for {head_mask.size()[0]}."
                )

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    # NOTE(review): head_mask is only forwarded on the gradient-checkpointing
                    # path above; this eager path drops it — confirm intentional.
                    layer_outputs = encoder_layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # Skipped layers still contribute a placeholder so indices stay aligned.
                layer_outputs = (None, None)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
+
+
@add_start_docstrings(
    "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.",
    MCTCT_START_DOCSTRING,
)
class MCTCTModel(MCTCTPreTrainedModel):
    """Thin wrapper around MCTCTEncoder that returns the encoder's raw hidden states."""

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = MCTCTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        # Resolve output options against the config when not explicitly given.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_features is None:
            raise ValueError("You have to specify input_features.")

        encoder_outputs = self.encoder(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
    MCTCT_START_DOCSTRING,
)
class MCTCTForCTC(MCTCTPreTrainedModel):
    """MCTCTModel plus a linear CTC head projecting hidden states to vocabulary logits."""

    def __init__(self, config):
        super().__init__(config)

        self.mctct = MCTCTModel(config)

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = config.hidden_size

        self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_CTC_EXPECTED_OUTPUT,
        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.mctct(
            input_features,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        # Project each frame's hidden state onto the vocabulary.
        logits = self.ctc_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            # NOTE(review): this fallback mask is created on the default device, not
            # input_features.device — confirm against callers that use GPU inputs.
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(input_features.shape[:-1], dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            # cuDNN's CTC kernel is disabled for reproducibility/compatibility.
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
+
+
+__all__ = ["MCTCTForCTC", "MCTCTModel", "MCTCTPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py
new file mode 100644
index 0000000000000000000000000000000000000000..f953c5895a0d2e1cf3a343fbdefc99861da2dadd
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py
@@ -0,0 +1,146 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Speech processor class for M-CTC-T
+"""
+
+import warnings
+from contextlib import contextmanager
+
+from ....processing_utils import ProcessorMixin
+
+
+class MCTCTProcessor(ProcessorMixin):
+ r"""
+ Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor.
+
+ [`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the
+ [`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information.
+
+ Args:
+ feature_extractor (`MCTCTFeatureExtractor`):
+ An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input.
+ tokenizer (`AutoTokenizer`):
+ An instance of [`AutoTokenizer`]. The tokenizer is a required input.
+ """
+
+ feature_extractor_class = "MCTCTFeatureExtractor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, feature_extractor, tokenizer):
+ super().__init__(feature_extractor, tokenizer)
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+ def __call__(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
+ [`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor(*args, **kwargs)
+
+ if "raw_speech" in kwargs:
+ warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+ audio = kwargs.pop("raw_speech")
+ else:
+ audio = kwargs.pop("audio", None)
+ sampling_rate = kwargs.pop("sampling_rate", None)
+ text = kwargs.pop("text", None)
+ if len(args) > 0:
+ audio = args[0]
+ args = args[1:]
+
+ if audio is None and text is None:
+ raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+ if audio is not None:
+ inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
+ if text is not None:
+ encodings = self.tokenizer(text, **kwargs)
+
+ if text is None:
+ return inputs
+ elif audio is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def pad(self, *args, **kwargs):
+ """
+ When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
+ [`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context
+ [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
+ [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
+ """
+ # For backward compatibility
+ if self._in_target_context_manager:
+ return self.current_processor.pad(*args, **kwargs)
+
+ input_features = kwargs.pop("input_features", None)
+ labels = kwargs.pop("labels", None)
+ if len(args) > 0:
+ input_features = args[0]
+ args = args[1:]
+
+ if input_features is not None:
+ input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+ if labels is not None:
+ labels = self.tokenizer.pad(labels, **kwargs)
+
+ if labels is None:
+ return input_features
+ elif input_features is None:
+ return labels
+ else:
+ input_features["labels"] = labels["input_ids"]
+ return input_features
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @contextmanager
+ def as_target_processor(self):
+ """
+ Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.
+ """
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your audio inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
+ self.current_processor = self.tokenizer
+ yield
+ self.current_processor = self.feature_extractor
+ self._in_target_context_manager = False
+
+
+__all__ = ["MCTCTProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff2c19505f9306c28edb5bdabb66f668363f5fa
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_mega import *
+ from .modeling_mega import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b9d53d52079f0dde5237b50ce36730685a3767e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py
@@ -0,0 +1,243 @@
+# coding=utf-8
+# Copyright 2023 The Mega Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MEGA configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ....configuration_utils import PretrainedConfig
+from ....onnx import OnnxConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MegaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MegaModel`]. It is used to instantiate a Mega
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the Mega
+ [mnaylor/mega-base-wikitext](https://huggingface.co/mnaylor/mega-base-wikitext) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the Mega model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`MegaModel`].
+ hidden_size (`int`, *optional*, defaults to 128):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 4):
+ Number of hidden layers in the Mega encoder.
+ intermediate_size (`int`, *optional*, defaults to 256):
+ Dimensionality of the hidden size (self-attention value projection) within the Mega encoder
+ ema_projection_size (`int`, *optional*, defaults to 16):
+ Dimensionality of the MegaMultiDimensionDampedEma
+ bidirectional (`bool`, *optional*, defaults to `True`):
+ Whether the MegaMultiDimensionDampedEma used in Mega's self-attention should work bidirectionally (`True`)
+ or unidirectionally (`False`). Bidirectional EMA is incompatible with causal decoding, so this should be
+ False if you intend to use the model as a decoder.
+ shared_representation_size (`int`, *optional*, defaults to 64):
+ Dimensionality of the linear projection for shared representation of self-attention queries and keys
+ use_chunking (`bool`, *optional*, defaults to `False`):
+ Whether to chunk inputs for linear self-attention complexity (described as Mega-chunk in the paper)
+ chunk_size (`int`, *optional*, defaults to -1):
+ If `use_chunking` is set to `True`, determines the size of the chunks to apply to the input sequence. If
+ chunking is used, input sequences must be padded to a multiple of `chunk_size`
+ truncation (`int`, *optional*):
+ If specified, the sequence length for which to truncate MegaMultiDimensionDampedEma
+ normalize_before_mega (`bool`, *optional*, defaults to `True`):
+ Whether to normalize before (`True`) or after (`False`) passing through Mega encoder blocks
+ normalization_type (`str`, *optional*, defaults to `"scalenorm"`):
+ Type of normalization to use in Mega encoder blocks. Choose one of `"scalenorm"`, `"layernorm"`,
+ `"rmsnorm"`, `"batchnorm"`, or `"syncbatchnorm"` (GPU required for syncbatchnorm)
+ norm_affine (`bool`, *optional*, defaults to `True`):
+ If `True`, applies a parameterized affine transformation to inputs during normalization
+ activation (`str`, *optional*, defaults to `"silu"`):
+ Activation function to apply within Mega encoder blocks. Choose one of `"silu"`, `"relu"`, `"linear"`,
+ `"gelu"`, or `"gelu_accurate"`
+ attention_activation (`str`, *optional*, defaults to `"softmax"`):
+ Activation function to apply for single-headed self-attention (a la Transformer). Choose one of
+ `"softmax"`, `"laplace"`, or `"relu2"`
+ dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for EMA self-attention
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ use_feature_dropout (`bool`, *optional*, defaults to `False`):
+ Whether to use feature-based (`True`) or standard dropout (`False`)
+ use_normalized_ffn (`bool`, *optional*, defaults to `True`):
+ Whether to use the normalized feed-forward sub-layer in Mega blocks (`True`) or pass Mega encoder output
+ as-is (`False`)
+ nffn_hidden_size (`int`, *optional*, defaults to 256):
+ If using the normalized feed-forward network (NFFN) layer within Mega (`use_normalized_ffn = True`), this
+ is the hidden size of the NFFN
+ normalize_before_ffn (`bool`, *optional*, defaults to `True`):
+ Whether to normalize before (`True`) or after (`False`) the feed-forward portion of NFFN
+ nffn_activation_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the NFFN component.
+ max_positions (`int`, *optional*, defaults to 2048):
+ The maximum sequence length to use for positional representations. For `"simple"` relative positional bias,
+ this is a hard limit on input length; `"rotary"` relative positional bias will extrapolate to longer
+ sequences
+ add_token_type_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to account for token types in embeddings. Left as optional to maintain compatibility with original
+ implementation while adding support for token types.
+ type_vocab_size (`int`, *optional*, defaults to 2):
+ The vocabulary size of the `token_type_ids` passed when calling [`MegaModel`]. Only used if
+ `add_token_type_embeddings = True`
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ ema_delta_alpha_range (`float`, *optional*, defaults to 0.2):
+ The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in
+ MegaMultiDimensionDampedEma.
+ ema_beta_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation for initializing the beta parameter (expansion matrix) in
+ MegaMultiDimensionDampedEma.
+ ema_gamma_omega_range (`float`, *optional*, defaults to 1.0):
+ The standard deviation for initializing the gamma (projection matrix) and omega (residual weight)
+ parameters in MultiDimensionEMA.
+ relative_positional_bias (`str`, *optional*, defaults to `"rotary"`):
+ Type of relative positional encoding. Choose one of `"rotary"` or `"simple"`. If `"simple"` is selected,
+ `max_positions` is used as a limit on input size, while `"rotary"` extrapolates beyond `max_positions`.
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+ add_lm_hidden_dense_layer (`bool`, *optional*, defaults to `True`):
+ Whether to include a hidden layer for projection between encoder outputs and LM heads (`True`) or pass
+ hidden states directly to LM head (`False`). Remains optional for compatibility with original
+ implementation
+
+ Examples:
+
+ ```python
+ >>> from transformers import MegaConfig, MegaModel
+
+ >>> # Initializing a Mega configuration
+ >>> configuration = MegaConfig()
+
+ >>> # Initializing a model (with random weights) from the configuration
+ >>> model = MegaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "mega"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=128,
+ num_hidden_layers=4,
+ intermediate_size=256,
+ ema_projection_size=16,
+ bidirectional=True,
+ shared_representation_size=64,
+ use_chunking=False,
+ chunk_size=-1,
+ truncation=None,
+ normalize_before_mega=True,
+ normalization_type="scalenorm",
+ norm_affine=True,
+ activation="silu",
+ attention_activation="softmax",
+ dropout_prob=0.1,
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ use_feature_dropout=False,
+ use_normalized_ffn=True,
+ nffn_hidden_size=256,
+ normalize_before_ffn=True,
+ nffn_activation_dropout_prob=0.1,
+ max_positions=2048,
+ add_token_type_embeddings=False,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ ema_delta_alpha_range=0.2,
+ ema_beta_range=0.02,
+ ema_gamma_omega_range=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ relative_positional_bias="rotary",
+ classifier_dropout=None,
+ use_cache=True,
+ add_lm_hidden_dense_layer=True,
+ **kwargs,
+ ):
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.activation = activation
+ self.attention_activation = attention_activation
+ self.intermediate_size = intermediate_size
+ self.ema_projection_size = ema_projection_size
+ self.bidirectional = bidirectional
+ self.shared_representation_size = shared_representation_size
+ self.use_chunking = use_chunking
+ self.chunk_size = chunk_size
+ self.truncation = truncation
+ self.normalize_before_mega = normalize_before_mega
+ self.normalization_type = normalization_type
+ self.norm_affine = norm_affine
+ self.dropout_prob = dropout_prob
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.use_feature_dropout = use_feature_dropout
+ self.use_normalized_ffn = use_normalized_ffn
+ self.nffn_hidden_size = nffn_hidden_size
+ self.normalize_before_ffn = normalize_before_ffn
+ self.nffn_activation_dropout_prob = nffn_activation_dropout_prob
+ self.max_positions = max_positions
+ self.add_token_type_embeddings = add_token_type_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.ema_delta_alpha_range = ema_delta_alpha_range
+ self.ema_beta_range = ema_beta_range
+ self.ema_gamma_omega_range = ema_gamma_omega_range
+ self.relative_positional_bias = relative_positional_bias
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
+ self.add_lm_hidden_dense_layer = add_lm_hidden_dense_layer
+ self.num_attention_heads = 1 # not used but required by Hugging Face
+
+
+class MegaOnnxConfig(OnnxConfig):
+ @property
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
+ if self.task == "multiple-choice":
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+ else:
+ dynamic_axis = {0: "batch", 1: "sequence"}
+ return OrderedDict(
+ [
+ ("input_ids", dynamic_axis),
+ ("attention_mask", dynamic_axis),
+ ]
+ )
+
+
+__all__ = ["MegaConfig", "MegaOnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6dbb12890e55fd26c39794aa28018265e98e164
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,298 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Convert Mega pretrained checkpoint. Built to convert the Masked LM checkpoint located at
+https://huggingface.co/mnaylor/mega-wikitext-103
+
+Requirements:
+ - clone the Mega repo and install fairseq from there
+ 1. git clone https://github.com/facebookresearch/mega.git
+ 2. cd mega && pip install -e .
+ - clone the pretrained weights for the original implementation from the hugging face repo
+ * use this location as the path for pretrained weights
+"""
+
+import argparse
+
+# utilities to import the model weights and config file
+import os
+import pickle as pkl
+
+# PyTorch + new model classes
+import torch
+from torch import nn
+
+from transformers import AutoTokenizer, MegaConfig, MegaForMaskedLM
+
+
+# import the EncoderLayer class used to pretrain
+# !! NOTE !! this requires the version of fairseq that is built when you install the Mega source
+try:
+ from fairseq.modules.mega_layer import MegaEncoderLayer
+except ImportError:
+ raise ImportError("You need to install the version of fairseq from the Mega repo!")
+
+
+# define the wrapper classes used to train the MLM (see colab notebook below)
+# https://colab.research.google.com/drive/1qfUO6o5HRdxBblWlw058HVyvaEPhPpH8?usp=sharing
+# MegaLM outputs hidden states
+class MegaLM(nn.Module):
+ "The base class for our Mega encoder - given input IDs, embed text and return encoder output"
+
+ def __init__(self, mega_args, depth, vocab_size):
+ super().__init__()
+ self.mega_args = mega_args
+ self.embedding_layer = nn.Embedding(vocab_size, self.mega_args.encoder_embed_dim)
+ self.encoders = nn.ModuleList([MegaEncoderLayer(self.mega_args) for _ in range(depth)])
+ self.depth = depth
+
+ def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
+ """
+ Code for a forward pass - expects input_ids and attention_mask to come from a Hugging Face tokenizer as PyTorch
+ tensors, and returns a tensor of size (batch, n_classes) containing classification logits
+
+ Other options:
+ - batch_first: boolean indicating whether the batch dimension is first in input_ids (default: True, which
+ aligns with the HF tokenizer behavior)
+ - ignore_mask_value: the value in attention_mask that identifies tokens that should be ignored (default: 0,
+ which aligns with HF tokenizer)
+ """
+
+ # Mega expects embeddings to be (time, batch, embedding size), but
+ # Hugging Face returns tokens as (batch, time)
+ if batch_first:
+ input_ids = input_ids.T
+
+ # to make things more confusing, Mega expects the attention mask to
+ # be (batch, time), but with values of 0 (normal token) and 1 (ignore token)
+ # which is the opposite of what HF returns
+ if ignore_mask_value == 0:
+ attention_mask = 1 - attention_mask
+
+ # get token embeddings from IDs
+ embeds = self.embedding_layer(input_ids)
+
+ # pass through the Mega layers
+ # input is (time, batch, encoder dim) and output is the same
+ for encoder in self.encoders:
+ embeds = encoder(embeds, attention_mask)
+
+ # return according to the shape specified
+ if batch_first:
+ # (T, B, H) --> (B, T, H)
+ return torch.transpose(embeds, 0, 1)
+ else:
+ return embeds
+
+
+# renamed from MegaForMaskedLM to avoid confusion with new module
+class OriginalMegaForMaskedLM(nn.Module):
+ "A wrapper class for doing masked language modeling with Mega"
+
+ def __init__(self, mega_args, depth, vocab_size):
+ super().__init__()
+ self.mega = MegaLM(mega_args, depth, vocab_size)
+ self.mlm_head = nn.Linear(mega_args.encoder_embed_dim, vocab_size)
+ self.dropout = nn.Dropout(p=0.1)
+
+ def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
+ """
+ Perform a forward pass through the Mega encoder and the masked LM head. Returns logits for each vocabulary
+ entry.
+
+ If `batch_first` (default to align with Hugging Face tokenizer behavior), output will have the shape (Batch
+ size, Sequence length, Vocab size); otherwise (S, B, V)
+ """
+ encoder_output = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value)
+ return self.mlm_head(self.dropout(encoder_output))
+
+
+# code to convert the checkpoint located in the user-specified location
+def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
+ with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
+ mega_original_args = pkl.load(f)
+
+ # load the original encoder
+ original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()
+
+ # load its weights
+ print(
+ "Original Mega encoder:",
+ original_mlm.mega.load_state_dict(
+ torch.load(
+ os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True
+ )
+ ),
+ )
+ print(
+ "Original Mega MLM layer:",
+ original_mlm.mlm_head.load_state_dict(
+ torch.load(
+ os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
+ )
+ ),
+ )
+
+ # create a new config from the old one
+ hf_config = MegaConfig(
+ num_hidden_layers=mega_original_args["depth"],
+ vocab_size=mega_original_args["vocab_size"],
+ hidden_size=mega_original_args["mega_args"].encoder_embed_dim,
+ shared_representation_size=mega_original_args["mega_args"].encoder_z_dim,
+ intermediate_size=mega_original_args["mega_args"].encoder_hidden_dim,
+ ema_projection_size=mega_original_args["mega_args"].encoder_n_dim,
+ dropout_prob=mega_original_args["mega_args"].dropout,
+ attention_probs_dropout_prob=mega_original_args["mega_args"].attention_dropout,
+ hidden_dropout_prob=mega_original_args["mega_args"].hidden_dropout,
+ activation=mega_original_args["mega_args"].activation_fn,
+ attention_activation=mega_original_args["mega_args"].attention_activation_fn,
+ bidirectional=mega_original_args["mega_args"].bidirectional,
+ use_chunking=mega_original_args["mega_args"].encoder_chunk_size > 0,
+ chunk_size=mega_original_args["mega_args"].encoder_chunk_size,
+ truncation=mega_original_args["mega_args"].truncation_length,
+ normalization_type=mega_original_args["mega_args"].normalization_type,
+ normalize_before_mega=True,
+ norm_affine=True,
+ use_feature_dropout=mega_original_args["mega_args"].feature_dropout,
+ relative_positional_bias=mega_original_args["mega_args"].rel_pos_bias,
+ max_positions=mega_original_args["mega_args"].max_source_positions,
+ nffn_hidden_size=mega_original_args["mega_args"].encoder_ffn_embed_dim,
+ normalize_before_ffn=mega_original_args["mega_args"].normalize_before,
+ # new arguments added for HF implementation
+ nffn_activation_dropout_prob=0.0,
+ add_token_type_embeddings=False,
+ add_lm_hidden_dense_layer=False,
+ )
+
+ hf_mlm = MegaForMaskedLM(hf_config).eval()
+
+ # the original checkpoint just uses nn.Embedding for the word embeddings
+ # we use a wrapper module for embeddings to add support for positional embeddings
+ hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight
+
+ # modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face
+ # ecosystem -- any names containing "beta" or "gamma" aren't safe to use and are renamed upon _load_pretrained,
+ # also renaming previously confusing parameter names
+ original_state_dict = original_mlm.mega.encoders.state_dict()
+ updated_keys = {}
+ for module_name in original_state_dict.keys():
+ new_module_name = None
+ # have to handle gamma, beta, and alpha differently due to their use
+ # in multiple modules within the original repository;
+ # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed due to flax/tf weights
+ # the EMA sublayer was renamed from "move" to "ema_gate" for readability, so that is also done here
+ if "beta" in module_name:
+ # EMA sub-layers were always called "move" in the original repo
+ if "move.beta" in module_name:
+ new_module_name = module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
+ elif "mega_layer.beta" in module_name:
+ new_module_name = module_name.replace("beta", "qk_bias")
+ else:
+ new_module_name = module_name.replace("beta", "b_param")
+ # gamma is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed due to flax/tf weights
+ elif "gamma" in module_name:
+ if "move.gamma" in module_name:
+ new_module_name = module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
+ elif "mega_layer.gamma" in module_name:
+ new_module_name = module_name.replace("gamma", "qk_weight")
+ else:
+ new_module_name = module_name.replace("gamma", "g_param")
+ # alpha is used in EMA and positional bias; renaming to improve readability
+ elif "move.alpha" in module_name:
+ new_module_name = module_name.replace("move.alpha", "ema_gate.decay_factor")
+ # delta is only used in EMA; renaming to improve readability
+ elif "move.delta" in module_name:
+ new_module_name = module_name.replace("move.delta", "ema_gate.damping_factor")
+ # omega is only used in EMA; renaming to improve readability
+ elif "omega" in module_name:
+ new_module_name = module_name.replace("move.omega", "ema_gate.residual_weight")
+
+ if new_module_name:
+ updated_keys[module_name] = new_module_name
+
+ if len(updated_keys) != 0:
+ print(f"Renaming these keys: {updated_keys.keys()}")
+ else:
+ print("No need to rename state dict entries")
+ for old, new in updated_keys.items():
+ original_state_dict[new] = original_state_dict.pop(old)
+
+ # now attempt to load the state dictionary with updated names
+ # note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style
+ print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))
+
+ # load the MLM head weights directly
+ print(
+ "HF Mega MLM layer:",
+ hf_mlm.mlm_head.load_state_dict(
+ torch.load(
+ os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
+ )
+ ),
+ )
+
+ # test on a randomly generated input sequence
+ input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
+ input_mask = torch.ones_like(input_ids)
+ # mask a few tokens to make sure masking is applied appropriately :)
+ input_mask[:, -10:] = 0
+
+ # run forward passes
+ original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
+ hf_output = hf_mlm(input_ids, input_mask)[0]
+
+ # print shapes and diff
+ print(f"original output {original_output.shape}")
+ print(f"hf output {hf_output.shape}")
+ print(f"max diff: {(original_output - hf_output).max()}") # 0.0
+ success = torch.allclose(original_output, hf_output, atol=1e-3)
+
+ if success:
+ print("Yay!")
+ hf_mlm.save_pretrained(output_path)
+ else:
+ raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")
+
+ if includes_tokenizer:
+ print("Transferring tokenizer")
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
+ tokenizer.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--pretrained_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Point to the directory containing your model weights using the official Mega repo",
+ )
+
+ parser.add_argument(
+ "--output_path", default=None, type=str, required=True, help="Location to save the Hugging Face version"
+ )
+
+ parser.add_argument(
+ "--includes_tokenizer",
+ action="store_true",
+ help="Use this flag if there is a Hugging Face tokenizer in the original checkpoint repo",
+ )
+
+ args = parser.parse_args()
+
+ convert_checkpoint_to_huggingface(args.pretrained_checkpoint_path, args.output_path, args.includes_tokenizer)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5a490b01d3cc893758490c669f3928b93edab65
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py
@@ -0,0 +1,2285 @@
+# coding=utf-8
+# Copyright 2023 The Mega Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MEGA model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import ALL_LAYERNORM_LAYERS
+from ....utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_mega import MegaConfig
+
+
# Module-level logger following the transformers logging convention
logger = logging.get_logger(__name__)

# Checkpoint / config identifiers injected into the auto-generated docstrings below
_CHECKPOINT_FOR_DOC = "mnaylor/mega-base-wikitext"
_CONFIG_FOR_DOC = "MegaConfig"
+
+
class MegaEmbeddings(nn.Module):
    """
    Mega's basic implementation does not incorporate token type embeddings, so this is a stripped-down version of
    RoBERTa's embeddings which optionally includes token types
    """

    def __init__(self, config: MegaConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.use_token_types = config.add_token_type_embeddings
        if self.use_token_types:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
            # registering a buffer here allows model tracing when not passing optional token type IDs
            # more info at transformers issue #5664
            self.register_buffer(
                "token_type_ids", torch.zeros(config.max_positions, dtype=torch.long).expand((1, -1)), persistent=False
            )

        self.padding_idx = config.pad_token_id

    def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
        """Embed token IDs (or pass through precomputed embeddings), optionally adding token type embeddings."""
        if (input_ids is None) and (inputs_embeds is None):
            raise ValueError("Must provide one of input_ids or inputs_embeds")

        if input_ids is not None:
            # look up word embeddings when only IDs are provided
            input_shape = input_ids.size()
            device = input_ids.device
            inputs_embeds = self.word_embeddings(input_ids)
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        # the original Mega implementation did not include token type embeddings, so they
        # are optional here; without them the word embeddings pass through unchanged
        if not self.use_token_types:
            return inputs_embeds

        # when token type IDs are omitted, fall back to the registered all-zeros buffer
        # (keeps the module traceable)
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, : input_shape[1]].expand(input_shape[0], input_shape[1])
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # sum word and token type embeddings
        return inputs_embeds + self.token_type_embeddings(token_type_ids)
+
+
class MegaSimpleRelativePositionalBias(nn.Module):
    """
    Simple relative positional embeddings copied from the Mega repo; renamed variables for better readability
    """

    def __init__(self, config: MegaConfig):
        super().__init__()
        self.config = config
        # when chunking is enabled, a bias never needs to span more than one chunk
        self.max_positions = self.config.max_positions if self.config.chunk_size < 0 else self.config.chunk_size
        self.rel_pos_bias = nn.Parameter(torch.Tensor(2 * config.max_positions - 1))

    def forward(self, seq_len):
        """Return a (seq_len x seq_len) matrix of learned relative-position biases."""
        if seq_len > self.max_positions:
            raise ValueError("Sequence length {} going beyond max length {}".format(seq_len, self.max_positions))

        # slice out the (2 * seq_len - 1) biases centered on relative position 0
        window = self.rel_pos_bias[(self.max_positions - seq_len) : (self.max_positions + seq_len - 1)]
        # pad to length (3 * seq_len - 1), tile seq_len times, and drop the trailing pad;
        # reshaping then yields one row per query position, each shifted by one step
        padded = F.pad(window, (0, seq_len))
        tiled = torch.tile(padded, (seq_len,))[:-seq_len]
        shifted = tiled.view(seq_len, 3 * seq_len - 2)
        # keep only the central seq_len columns of each row
        first_col = (2 * seq_len - 1) // 2
        last_col = shifted.size(1) - first_col
        return shifted[:, first_col:last_col]
+
+
class MegaRotaryRelativePositionalBias(nn.Module):
    """
    Rotary relative bias for positional information; similar in concept to RoPE (i.e. RoFormer) but taken from the Mega
    repo due to differences in implementation.

    When initialized, produces a positional bias which ranges from position 0 to config.max_positions, but can
    extrapolate to longer sequences. Can be indexed according to input position IDs
    """

    def __init__(self, config: MegaConfig):
        super().__init__()
        if config.hidden_size % 2 != 0:
            raise RuntimeError("Rotary positional bias requires `hidden_size` to be a multiple of 2")
        self.config = config
        self.embed_dim = config.shared_representation_size
        self.max_positions = self.config.max_positions if self.config.chunk_size < 0 else self.config.chunk_size
        # precomputed sinusoid tables; regenerated lazily in `rotary` if a longer sequence shows up
        self.sine, self.cosine = MegaRotaryRelativePositionalBias.get_sinusoid_embeddings(
            config.max_positions, self.embed_dim
        )
        # alpha and beta parameters for the rotary bias; beta renamed to b_param to avoid clashes with tf/flax weight handling
        # in loading pretrained weights
        self.alpha = nn.Parameter(torch.Tensor(1, self.embed_dim))
        self.b_param = nn.Parameter(torch.Tensor(1, self.embed_dim))
        # dummy buffer used only to track the module's current device/dtype for the sinusoid tables
        self.register_buffer("_float_tensor", torch.FloatTensor([0.0]))

    @staticmethod
    def get_sinusoid_embeddings(max_positions: int, embedding_dim: int):
        """Build (max_positions x embedding_dim // 2) sine and cosine tables with log-spaced frequencies."""
        half_dim = embedding_dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        return torch.sin(emb), torch.cos(emb)

    def rotary(self, input):
        """Apply the rotary transformation to `input` of shape (seq_len, embed_dim).

        Extends the cached sine/cosine tables in place if `seq_len` exceeds their current size
        (this mutation is how extrapolation beyond `config.max_positions` is supported).
        """
        seq_len, embed_dim = input.size()
        chunk_1, chunk_2 = torch.chunk(input, 2, dim=-1)
        if self.sine is None or seq_len > self.sine.size(0):
            self.sine, self.cosine = MegaRotaryRelativePositionalBias.get_sinusoid_embeddings(seq_len, embed_dim)
            self.max_positions = seq_len
        # move the cached tables to the module's device/dtype via the tracking buffer
        self.sine = self.sine.to(self._float_tensor)
        self.cosine = self.cosine.to(self._float_tensor)

        sin = self.sine[:seq_len]
        cos = self.cosine[:seq_len]
        return torch.cat([chunk_1 * cos - chunk_2 * sin, chunk_2 * cos + chunk_1 * sin], dim=1)

    def forward(self, seq_len):
        """Return a (seq_len x seq_len) rotary relative-position bias matrix."""
        rotary_alpha = self.rotary(self.alpha.expand(seq_len, self.embed_dim))
        rotary_beta = self.rotary(self.b_param.expand(seq_len, self.embed_dim))
        # pairwise inner products between rotated alpha and beta give the bias matrix
        bias = torch.einsum("mk,nk->mn", rotary_alpha, rotary_beta)
        return bias
+
+
class MegaDropout(nn.Module):
    """
    A unified class for standard dropout functionality and featurewise dropout.

    The original fairseq Mega repo used 2 classes for these, which included some unnecessary handling of training logic
    and an unused `inplace` option. The original implementation used torch.nn.functional instead of submodules, which
    is retained here as well.
    """

    def __init__(self, dropout_probability, is_featurewise=False):
        super().__init__()
        self.dropout_probability = dropout_probability  # probability of zeroing an element (or feature channel)
        self.is_featurewise = is_featurewise  # if True, drop whole feature channels via dropout2d

    def forward(self, input, batch_first: bool = False):
        """Apply standard or featurewise dropout; `batch_first` selects the input layout for featurewise mode."""
        if not self.is_featurewise:
            # plain elementwise dropout
            return F.dropout(input, p=self.dropout_probability, training=self.training)

        if batch_first:
            # (batch_size X sequence_length X feature_dimension)
            # -> (batch_size X feature_dimension X sequence_length) so dropout2d zeroes
            # entire feature channels, then restore the original layout
            channels_first = input.transpose(-1, -2)
            dropped = F.dropout2d(channels_first, p=self.dropout_probability, training=self.training)
            return dropped.transpose(-1, -2)

        if input.dim() != 3:
            raise ValueError(
                "Feature dropout inputs must be exactly 3-dimensional if inputs are ordered [sequence length, batch size, hidden dimension]"
            )
        # (sequence_length X batch_size X feature_dimension)
        # -> (batch_size X feature_dimension X sequence_length)
        # -> (sequence_length X batch_size X feature_dimension)
        dropped = F.dropout2d(input.permute(1, 2, 0), p=self.dropout_probability, training=self.training)
        return dropped.permute(2, 0, 1)
+
+
class MegaRMSNorm(nn.Module):
    """
    RMSNorm used in Mega implementation. Differs from T5's RMSNorm by applying the weight prior to taking the square
    root (as opposed to after in T5)
    """

    def __init__(self, number_features, eps=1e-6, affine=True):
        super().__init__()
        self.num_features = number_features  # size of the normalized (last) dimension
        self.eps = eps  # numerical-stability constant added to the mean square
        self.affine = affine  # whether a learned per-feature weight is applied
        if affine:
            self.weight = nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter("weight", None)

    def forward(self, input):
        """RMS-normalize `input` over its last dimension (weight applied before the rsqrt scaling).

        Bug fix: the original code evaluated `input * torch.rsqrt(mean_square + self.eps)` as a
        bare expression and discarded the result, so the layer returned the (optionally weighted)
        input without ever normalizing it. The product is now assigned back to `input`.
        """
        mean_square = torch.mean(torch.square(input), dim=-1, keepdim=True)
        if self.weight is not None:
            input = input * self.weight

        input = input * torch.rsqrt(mean_square + self.eps)
        return input

    def extra_repr(self):
        """Constructor arguments shown in `repr(module)`."""
        return f"{self.num_features}, eps={self.eps}, affine={self.affine}"
+
+
class MegaScaleNorm(nn.Module):
    """
    Scale normalization introduced in MEGA which is similar to RMSNorm, but uses a single parameter for scalar
    multiplication instead of a vector, and applies over a specified dimension
    """

    def __init__(self, dim, eps=1e-6, affine=True):
        super().__init__()
        self.dim = dim  # dimension over which the mean square is computed
        self.eps = eps  # numerical-stability constant
        self.affine = affine  # whether a single learned scalar multiplies the input
        if affine:
            self.scalar = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter("scalar", None)

    def forward(self, input):
        """Scale-normalize `input` along `self.dim` (scalar applied before the rsqrt scaling)."""
        mean_square = torch.square(input).mean(dim=self.dim, keepdim=True)
        scaled = input if self.scalar is None else self.scalar * input
        return scaled * torch.rsqrt(mean_square + self.eps)
+
+
class MegaSequenceNorm(nn.Module):
    """
    A wrapper class for various layer normalization options used in Mega. Used to handle differences in expectations on
    input axis locations for different normalization methods.
    """

    def __init__(self, norm_type, embedding_dim, eps=1e-5, affine=True, export=False):
        super().__init__()
        # dispatch on the configured normalization flavor
        if norm_type == "layernorm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine)
        elif norm_type == "scalenorm":
            self.norm = MegaScaleNorm(dim=-1, eps=eps, affine=affine)
        elif norm_type == "rmsnorm":
            self.norm = MegaRMSNorm(embedding_dim, eps=eps, affine=affine)
        elif norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine)
        elif norm_type == "syncbatchnorm":
            self.norm = nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine)
        else:
            raise ValueError("Unknown norm type: {}".format(norm_type))

    def forward(self, input):
        """Normalize `input`, rearranging axes for batch-norm variants which expect (batch, features, seq)."""
        if not isinstance(self.norm, nn.modules.batchnorm._BatchNorm):
            return self.norm(input)

        if input.dim() != 3:
            raise ValueError("BatchNorm inputs must be exactly 3-dimensional")
        # (seq, batch, features) -> (batch, features, seq), normalize, then restore the layout
        normalized = self.norm(input.permute(1, 2, 0))
        return normalized.permute(2, 0, 1)
+
+
# register MegaSequenceNorm so transformers utilities that special-case layer-norm
# modules (e.g. weight-decay exclusion lists) treat it like the built-in variants
ALL_LAYERNORM_LAYERS.append(MegaSequenceNorm)
+
+
class MegaMultiDimensionDampedEma(nn.Module):
    """
    Mega's Exponential Moving Average layer, largely left unmodified from the original repo with the exception of
    variable names and moving away from the stateful representation of incremental decoding state. See
    "https://arxiv.org/abs/2209.10655" for more details.
    """

    def __init__(self, config: MegaConfig):
        super().__init__()

        self.config = config

        self.embed_dim = config.hidden_size
        self.ndim = config.ema_projection_size
        self.bidirectional = config.bidirectional
        self.truncation = config.truncation
        self.scale = math.sqrt(1.0 / self.ndim)

        # bidirectional EMA doubles the kernel dimension (one half per direction)
        kernel_dim = 2 * config.hidden_size if self.bidirectional else config.hidden_size
        # renamed delta (damping_factor) and alpha (decay_factor) to be more descriptive of what the parameters are doing
        self.damping_factor = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1))
        self.decay_factor = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1))
        # renamed gamma (kernel_projection_matrix) and beta (ema_expansion_matrix) respectively to avoid HF renaming
        # things and align with the paper's description of these params' behavior
        self.ema_expansion_matrix = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1))
        self.kernel_projection_matrix = nn.Parameter(torch.Tensor(kernel_dim, self.ndim))
        # renamed omega to residual_weight to describe what it's doing
        self.residual_weight = nn.Parameter(torch.Tensor(config.hidden_size))
        # inference-only caches; kept as None during training so values are recomputed every call
        self._kernel = None
        self._coeffs = None

    def _compute_ema_coefficients(self):
        """Compute the EMA damping factor and previous-timestep weight, both mapped into [0, 1]."""
        self._coeffs = None
        # convert the alpha and delta parameters (kernel_dim x EMA projection size x 1) to [0, 1] with sigmoid
        damping_factor = torch.sigmoid(self.damping_factor)
        decay_factor = torch.sigmoid(self.decay_factor)
        previous_timestep_weight = 1.0 - damping_factor * decay_factor
        return damping_factor, previous_timestep_weight

    def _compute_efficient_ema_kernel(self, length: int):
        """Build the (kernel_dim x length) damped-EMA kernel used in FFT convolution."""
        # computes the kernel used for efficient damped EMA applied via FFT convolution
        self._kernel = None
        # p and q have shape (kernel_dim x ema_projection_size x 1)
        damping_factor, previous_timestep_weight = self._compute_ema_coefficients()
        # extend the kernel to (kernel_dim X ema_projection_size X sequence_length) and
        # multiply q by sequential ints up to the sequence length
        vander = torch.arange(length).to(damping_factor).view(1, 1, length) * torch.log(previous_timestep_weight)
        kernel = (damping_factor * self.ema_expansion_matrix) * torch.exp(vander)
        # (kernel_dim X ema_projection_size X sequence_length) -> (kernel_dim, sequence_length)
        return torch.einsum("dnl,dn->dl", kernel, self.kernel_projection_matrix * self.scale)

    def get_ema_coefficients(self):
        """Return EMA coefficients; recomputed each call in training, cached at inference time."""
        if self.training:
            return self._compute_ema_coefficients()
        else:
            if self._coeffs is None:
                self._coeffs = self._compute_ema_coefficients()
            return self._coeffs

    def get_ema_kernel(self, length: int):
        """Return the EMA kernel truncated to min(truncation, length); cached at inference time."""
        kernel_size = length if self.truncation is None else min(self.truncation, length)
        if self.training:
            return self._compute_efficient_ema_kernel(kernel_size)
        else:
            if self._kernel is None or self._kernel.size(-1) < kernel_size:
                self._kernel = self._compute_efficient_ema_kernel(kernel_size)
            return self._kernel[..., :kernel_size]

    def fft_convolution(self, inputs, kernel, length):
        """Convolve `inputs` with `kernel` via real FFT, zero-padded to 2 * length."""
        # this is a wrapper for repeated use of EMA calculation via FFT (fast Fourier transform) convolution
        inputs_fft = torch.fft.rfft(inputs.float(), n=2 * length)
        kernel_fft = torch.fft.rfft(kernel.float(), n=2 * length)
        convolved_sequence = torch.fft.irfft(inputs_fft * kernel_fft, n=2 * length)
        return convolved_sequence

    def ema_step(self, inputs, length, past_state=None):
        """Multi-token EMA update used during incremental decoding; returns the EMA output and
        the updated hidden state to carry into the next decoding step."""
        if length == 1:
            return self.one_ema_step(inputs, past_state=past_state)

        # (kernel_dim X ema_projection_size X 1)
        damping_factor, previous_timestep_weight = self.get_ema_coefficients()
        # (kernel_dim X ema_projection_size X 1+sequence_length)
        vander = torch.arange(length + 1).to(damping_factor).view(1, 1, length + 1) * torch.log(
            previous_timestep_weight
        )
        vander = torch.exp(vander)
        if past_state is not None:
            # (kernel_dim X ema_projection_size X sequence_length) * (kernel_dim X ema_projection_size X 1)
            # -> (kernel_dim X ema_projection_size X sequence_length)
            past_ema_proj = vander[:, :, 1:] * (self.kernel_projection_matrix * self.scale).unsqueeze(-1)
            # past_state will be (batch_size, kernel_dim, ema_projection_size)
            past_ema_state = torch.einsum("bdn,dnl->bdl", past_state, past_ema_proj)
            # (kernel_dim X ema_projection_size) * (batch_size X kernel_dim X ema_projection_size)
            # -> (batch_size X kernel_dim X ema_projection_size)
            past_vandermonde = vander[:, :, -1] * past_state
        else:
            past_ema_state = None
            past_vandermonde = None

        # (kernel_dim X ema_projection_size X sequence_length)
        vander = vander[:, :, :-1]
        kernel = (damping_factor * self.ema_expansion_matrix) * vander
        kernel_proj = torch.einsum("dnl,dn->dl", kernel, self.kernel_projection_matrix * self.scale)

        ema_output = self.fft_convolution(inputs, kernel_proj, length=length)[..., 0:length]
        ema_output = ema_output.type_as(inputs)
        if past_ema_state is not None:
            # fold in the contribution of the state carried over from previous steps
            ema_output = ema_output + past_ema_state

        updated_hidden_state = torch.einsum("bdl,dnl->bdn", inputs, torch.flip(kernel, dims=[2]))
        if past_vandermonde is not None:
            updated_hidden_state = updated_hidden_state + past_vandermonde
        # return a tuple:
        # (sequence_length, batch_size, kernel_dim)
        # (batch_size, kernel_dim, ema_projection_size)
        return ema_output.permute(2, 0, 1), updated_hidden_state

    def one_ema_step(self, inputs, past_state=None):
        """Single-timestep EMA recurrence: blend the current input with `past_state`."""
        damping_factor, previous_timestep_weight = self.get_ema_coefficients()
        # (kernel_dim X ema_projection_size) x (batch_size X kernel_dim X 1)
        # -> (batch_size X kernel_dim X ema_projection_size)
        updated_state = (damping_factor * self.ema_expansion_matrix).squeeze(-1) * inputs
        if past_state is not None:
            updated_state = updated_state + previous_timestep_weight.squeeze(-1) * past_state
        # (batch_size X kernel_dim)
        out = torch.einsum("bdn,dn->bd", updated_state, self.kernel_projection_matrix * self.scale)
        # (1 X batch_size X kernel_dim), (batch_size X kernel_dim X ema_projection_size)
        return out.unsqueeze(0), updated_state

    def forward(
        self,
        inputs,
        attention_mask: Optional[torch.Tensor] = None,
        prev_state: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """
        Mega's exponential moving average (EMA) sub-layer applied prior to single-headed (traditional) self-attention

        Args:
            inputs (`torch.Tensor` of shape `(sequence_length, batch_size, hidden_size)`):
                Hidden state / embedding input to update via EMA based on FFT convolution
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored (mostly due to padding), where elements are either 1 for *not
                masked* or 0 for *masked*
            prev_state (`torch.Tensor` of shape `(batch_size, config.ndim)`, *optional*):
                The hidden state returned from the previous timestep during incremental decoding.
            use_cache (`bool`, default `False`):
                Whether to perfom incremental decoding; uses `prev_state` as the prior timestep, and returns the
                updated EMA hidden state for use in the next step

        Returns:
            `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
            inputs:
            - **hidden_states** (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`) -- Hidden
              states updated by EMA, with same shapes as inputs
            - **updated_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor of shape `(batch_size,
              config.ndim)` -- The incremental EMA state for use in the next step of incremental decoding
        """

        seq_len, bsz, embed_dim = inputs.size()
        if embed_dim != self.embed_dim:
            raise ValueError(
                f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
            )

        # sequence_length X batch_size X hidden_size
        residual = inputs * self.residual_weight

        # (sequence_length x batch_size x hidden_size) -> (batch_size x hidden_size x sequence_length)
        inputs = inputs.permute(1, 2, 0)
        # mask the input: output is a tensor with 0 in the masked positions
        if attention_mask is not None:
            inputs = inputs * (attention_mask.unsqueeze(1).type_as(inputs))

        if self.bidirectional and use_cache:
            raise RuntimeError("Bidirectional EMA does not support incremental state")

        if use_cache:
            out, updated_state = self.ema_step(inputs, seq_len, past_state=prev_state)

            # (batch_size X hidden_size) -> (1 x batch_size x hidden_size)
            out = F.silu(out + residual)

            # if incremental decoding, return the new state along with the output
            return out, updated_state
        else:
            # (hidden_size x sequence_length)
            kernel = self.get_ema_kernel(seq_len)
            fft_len = seq_len
            s_index = 0
            kernel_size = kernel.size(1)
            if self.bidirectional:
                # split the kernel for each direction of EMA
                k1, k2 = torch.split(kernel, [self.embed_dim, self.embed_dim], dim=0)
                # (hidden_size X 2*sequence_length - 1)
                kernel = F.pad(k1, (kernel_size - 1, 0)) + F.pad(k2.flip(-1), (0, kernel_size - 1))
                inputs = F.pad(inputs, (kernel_size - 1, 0))
                fft_len = fft_len + kernel_size - 1
                s_index = 2 * kernel_size - 2

            ema_output = self.fft_convolution(inputs, kernel, length=fft_len)[..., s_index : s_index + seq_len]
            ema_output = ema_output.type_as(inputs)
            # (batch_size X hidden_size X sequence_length) -> (sequence_length X batch_size X hidden_size)
            gated_ema_output = F.silu(ema_output.permute(2, 0, 1) + residual)

            return gated_ema_output, None
+
+
class MegaGatedCrossAttention(nn.Module):
    """
    Gated Structured State Attention for use in encoder-decoder model. See Mega paper for more details. Only
    modifications from original implementation are variable names, removing the unnecessary `before_attn_fn` and
    `static_kv` arguments, and the stateful representation of incremental decoder state.
    """

    def __init__(self, config: MegaConfig):
        super().__init__()

        self.config = config
        self.activation = ACT2FN[self.config.activation]
        self.attention_activation = self.config.attention_activation
        # scaling is only used by the softmax attention variant
        self.scaling = self.config.shared_representation_size**-0.5 if self.attention_activation == "softmax" else None

        self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=self.config.use_feature_dropout)
        self.hidden_dropout = MegaDropout(
            self.config.hidden_dropout_prob, is_featurewise=self.config.use_feature_dropout
        )
        # Attention dropout is standard dropout
        self.attention_dropout = MegaDropout(self.config.attention_probs_dropout_prob, is_featurewise=False)

        self.prenorm = self.config.normalize_before_mega
        self.norm = MegaSequenceNorm(
            self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine
        )

        self.k_proj = nn.Linear(self.config.hidden_size, self.config.shared_representation_size)
        self.v_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        # q_proj produces three concatenated components: residual weight, target gate, and attention query
        self.q_proj = nn.Linear(
            self.config.hidden_size, 2 * self.config.hidden_size + self.config.shared_representation_size
        )
        self.h_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size)

        if self.config.relative_positional_bias == "simple":
            self.rel_pos_bias = MegaSimpleRelativePositionalBias(config)
        elif self.config.relative_positional_bias == "rotary":
            self.rel_pos_bias = MegaRotaryRelativePositionalBias(config)
        else:
            raise ValueError("unknown relative position bias: {}".format(self.config.relative_positional_bias))

        self.softmax = nn.Softmax(dim=-1)

    def element_attention(self, query, key, key_padding_mask, pidx):
        """Compute attention weights with an elementwise activation (non-softmax path),
        normalizing the query-key product by the unmasked source length; `pidx` selects a
        single target position during incremental decoding."""
        bsz, src_len, _ = key.size()
        tgt_len = query.size(1) if pidx is None else pidx + 1
        if key_padding_mask is not None:
            # (batch_size X source_sequence_length) --> (batch_size X 1 X 1)
            lengths = key_padding_mask.sum(dim=-1).view(bsz, 1, 1)
        else:
            lengths = src_len

        # (target_sequence_length X source_sequence_length)
        bias = self.rel_pos_bias(max(tgt_len, src_len))[:, :src_len]
        if pidx is not None:
            if query.size(1) != 1:
                raise ValueError("Position offset provided with queries longer than 1 token")
            # source_sequence_length
            bias = bias[pidx]
        else:
            # (target_sequence_length X source_sequence_length)
            bias = bias[:tgt_len]

        # (batch_size X target_sequence_length X source_sequence_length)
        qk = torch.bmm(query, key.transpose(1, 2)) / lengths + bias

        attn_weights = ACT2FN[self.attention_activation](qk).type_as(qk)

        if key_padding_mask is not None:
            # multiplicative masking: zero out weights at padded source positions
            attn_weights = attn_weights * key_padding_mask.unsqueeze(1)

        return attn_weights

    def softmax_attention(self, query, key, key_padding_mask, pidx):
        """Scaled dot-product softmax attention with relative position bias; padded source
        positions are masked additively with -inf before the softmax."""
        bsz, src_len, _ = key.size()
        tgt_len = query.size(1) if pidx is None else pidx + 1

        # (target_sequence_length X source_sequence_length)
        bias = self.rel_pos_bias(max(tgt_len, src_len))[:, :src_len]
        if pidx is not None:
            if query.size(1) != 1:
                raise ValueError("Position offset provided with queries longer than 1 token")
            # source_sequence_length
            bias = bias[pidx]
        else:
            # (target_sequence_length X source_sequence_length)
            bias = bias[:tgt_len]

        # scaled attention
        query = query * self.scaling
        # (batch_size X target_sequence_length X source_sequence_length)
        qk = torch.bmm(query, key.transpose(1, 2)) + bias

        if key_padding_mask is not None:
            qk = qk.masked_fill((1 - key_padding_mask).unsqueeze(1).to(torch.bool), float("-inf"))

        attn_weights = self.softmax(qk).type_as(qk)
        return attn_weights

    def forward(
        self,
        query,
        key: Optional[torch.Tensor],
        value: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Gated cross-attention used in Mega

        Args:
            query (`torch.Tensor` of shape `(target_sequence_length, batch_size, hidden_size)`):
                The self (or target) sequence input used as query inputs for cross-attention
            key (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`):
                The cross (or source) sequence input with shape used as keys in cross-attention
            value (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`):
                The cross (or source) sequence input with shape used as values in cross-attention
            key_padding_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*):
                Padding mask corresponding to the source sequence, where entries are 1 for *not masked* and 0 for
                *masked* tokens
            past_key_values (`tuple(torch.FloatTensor)`, *optional*):
                If provided, the hidden state returned from the previous timestep during incremental decoding; expects
                that prior cross-attention keys and values will be the last two items in the tuple
            output_attentions (`bool`, defaults to `False`):
                Whether or not to return the cross-attention weights.
            use_cache (`bool`, defaults to `False`):
                Whether to perfom incremental decoding; uses `prev_state` as the prior timestep, and returns the
                updated EMA hidden state for use in the next step

        Returns:
            `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
            inputs:
            - **hidden_states** (`torch.FloatTensor` of shape `(target_sequence_length, batch_size, hidden_size)`) --
              Hidden states from target sequence updated by gated cross-attention
            - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape
              `(batch_size, source_sequence_length, target_sequence_length)` -- The pairwise cross-attention weights
              corresponding to each token in the source and target sequences
            - **cross_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              source_sequence_length, config.shared_representation_size)` -- The cross-attention key state for use in
              the next step of incremental decoding
            - **cross_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              source_sequence_length, config.hidden_size)` -- The cross-attention value state for use in the next step
              of incremental decoding
        """

        seq_len, bsz, embed_dim = query.size()
        if embed_dim != self.config.hidden_size:
            raise ValueError(
                f"Unexpected embedding dimension received: input is {embed_dim} but expected {self.config.hidden_size}"
            )

        if past_key_values is not None:
            # make sure the inputs only have a sequence length of 1 if we're doing incremental decoding
            if seq_len != 1:
                raise ValueError(f"Incremental decoding requested with self-sequence length > 1: {seq_len}")
            # expect past_key_values to have (self_key, self_value, self_ema, cross_key, cross_value)
            prev_cross_key, prev_cross_value = past_key_values[-2:]
            key = value = None

            # use the self-attention cache to get the position id of the current step
            prev_self_key = past_key_values[0]
            num_incremental_steps = prev_self_key.size(1) + 1
        else:
            prev_cross_key = prev_cross_value = None
            # we still need the position id if we're doing incremental decoding (past_key_values will be None for the first step)
            num_incremental_steps = 0 if use_cache and (seq_len == 1) else None

        full_query = query
        if self.prenorm:
            full_query = self.norm(full_query)

        # (target_sequence_length X batch_size X 2*hidden_size + shared_representation_size)
        query_projected = self.q_proj(full_query)
        # split the query projections into separate components
        # - residual_weight is passed through sigmoid and sent through elementwise multiplication to the gated/weighted targets prior to being added to the query directly
        # - target_gate is a silu-gated tensor that is multiplied by the attention-weighted target below prior to residual connection
        # - attention_query is the part that is passed to the attention function
        residual_weight, target_gate, attention_query = torch.split(
            query_projected,
            [self.config.hidden_size, self.config.hidden_size, self.config.shared_representation_size],
            dim=-1,
        )

        # (target_sequence_length X batch_size X hidden_size)
        residual_weight = torch.sigmoid(residual_weight)
        target_gate = F.silu(target_gate)

        if key is None:
            if value is not None:
                raise ValueError("Key and value must be `None` simultaneously")
            projected_key = projected_value = None
        else:
            # (source_sequence_length X batch_size X shared_representation_size)
            projected_key = self.k_proj(key)
            # (source_sequence_length X batch_size X hidden_size)
            # NOTE(review): values are projected from `key`, not `value` — the `value` argument is only
            # used in the None-consistency check above; presumably key and value share the same source
            # sequence in practice. Confirm against the original fairseq implementation.
            projected_value = self.activation(self.v_proj(key))

        # (target_sequence_length X batch_size X shared_representation_size)
        # -> (batch_size X target_sequence_length X shared_representation_size)
        attention_query = attention_query.transpose(0, 1)
        if projected_key is not None:
            projected_key = projected_key.transpose(0, 1)
        if projected_value is not None:
            projected_value = projected_value.transpose(0, 1)

        # if we're doing incremental decoding, k and v are None and need to be overwritten with past values
        if past_key_values is not None:
            projected_key = prev_cross_key
            projected_value = prev_cross_value

        # if we're returning the cache for later use, store these now for later return (can be done without having past_key_values provided)
        if use_cache:
            updated_cross_key = projected_key
            updated_cross_value = projected_value

        ctx_len = projected_key.size(1)
        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            if key_padding_mask.size(0) != bsz:
                raise ValueError("Key padding mask does not align on the batch dimension")
            if key_padding_mask.size(1) != ctx_len:
                raise ValueError("Key padding mask does not align on the sequence length dimension")

        if self.attention_activation == "softmax":
            attn_weights = self.softmax_attention(
                attention_query, projected_key, key_padding_mask, num_incremental_steps
            )
        else:
            attn_weights = self.element_attention(
                attention_query, projected_key, key_padding_mask, num_incremental_steps
            )

        projected_value = self.hidden_dropout(projected_value, batch_first=True)
        kernel = self.attention_dropout(attn_weights)
        # (batch_size X target_sequence_length X hidden_size)
        # -> (target_sequence_length X batch_size X hidden_size)
        weighted_targets = torch.bmm(kernel, projected_value).transpose(0, 1)
        # (target_sequence_length X batch_size X hidden_size)
        weighted_targets = self.activation(self.h_proj(weighted_targets * target_gate))
        weighted_targets = self.dropout(weighted_targets)
        # gated residual: out = query + sigmoid(residual_weight) * (weighted_targets - query)
        out = torch.addcmul(query, residual_weight, weighted_targets - query)

        if not self.prenorm:
            out = self.norm(out)

        outputs = (out, attn_weights) if output_attentions else (out,)
        if use_cache:
            outputs = outputs + (updated_cross_key, updated_cross_value)

        return outputs
+
+
class MegaMovingAverageGatedAttention(nn.Module):
    """
    Pure PyTorch implementation of Mega block; see https://arxiv.org/abs/2209.10655 and original fairseq implementation
    at https://github.com/facebookresearch/mega (copyright Meta Research, licensed under MIT License)

    Differences from original implementation include hidden state refactor and fixed inconsistency with additive /
    multiplicative attention masks
    """

    def __init__(self, config: MegaConfig):
        super().__init__()
        self.config = config
        self.activation = ACT2FN[self.config.activation]
        # softmax attention is scaled by 1/sqrt(d); element-wise attention activations are unscaled
        self.scaling = (
            self.config.shared_representation_size**-0.5 if self.config.attention_activation == "softmax" else None
        )
        self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=self.config.use_feature_dropout)
        self.hidden_dropout = MegaDropout(
            self.config.hidden_dropout_prob, is_featurewise=self.config.use_feature_dropout
        )
        # attention dropout is standard dropout
        self.attention_dropout = MegaDropout(self.config.attention_probs_dropout_prob, is_featurewise=False)

        self.norm = MegaSequenceNorm(
            self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine
        )
        self.ema_gate = MegaMultiDimensionDampedEma(config)

        self.v_proj = nn.Linear(self.config.hidden_size, self.config.intermediate_size)
        # fused projection; its output is split in forward() into residual weight,
        # query/key + gates, and the intermediate state (see torch.split there)
        self.mx_proj = nn.Linear(
            self.config.hidden_size,
            self.config.shared_representation_size + self.config.intermediate_size + 2 * self.config.hidden_size,
        )
        self.h_proj = nn.Linear(self.config.intermediate_size, self.config.hidden_size)

        # per-(query, key) affine transform applied to the shared Q/K representation
        self.qk_weight = nn.Parameter(torch.Tensor(2, self.config.shared_representation_size))
        self.qk_bias = nn.Parameter(torch.Tensor(2, self.config.shared_representation_size))

        if self.config.relative_positional_bias == "simple":
            self.rel_pos_bias = MegaSimpleRelativePositionalBias(config)
        elif self.config.relative_positional_bias == "rotary":
            self.rel_pos_bias = MegaRotaryRelativePositionalBias(config)
        else:
            raise ValueError(f"Unknown relative positional bias: {self.config.relative_positional_bias}")

        self.softmax = nn.Softmax(dim=-1)
        # choose the attention kernel once at construction time
        self.attention_function = (
            self.softmax_attention if self.config.attention_activation == "softmax" else self.element_attention
        )

    def element_attention(self, query, key, padding_mask, causal_mask):
        """
        Apply element-wise attention via relu^2 or laplace. Same as original implementation but with standardized
        causal attention mask. Expects the Hugging Face standard attention mask paradigm: 1 for not masked, and 0 for
        masked.
        """
        seq_len = key.size(2)
        if padding_mask is not None:
            # (batch_size X number of chunks X 1)
            lengths = padding_mask.sum(-1, keepdim=True)
            # (batch_size X number of chunks X 1 X 1)
            lengths = lengths.clamp(min=1.0).unsqueeze(-1)
        else:
            lengths = seq_len

        # NOTE(review): when both padding_mask and causal_mask are provided, the causal-mask lengths
        # overwrite the padding-based lengths computed above (matches the original implementation)
        if causal_mask is not None:
            lengths = causal_mask.sum(dim=-1, keepdim=True)

        # (sequence_length X sequence_length)
        bias = self.rel_pos_bias(seq_len)
        if seq_len != query.size(2):
            if query.size(2) != 1:
                raise ValueError("Size mismatch between Q and K in element attention")
            # incremental decoding: single query attends over the full key length,
            # so keep only the last row of the positional bias
            # (1 X sequence_length)
            bias = bias[-1:]

        # (batch_size X number of chunks X sequence_length X sequence_length)
        qk = torch.matmul(query, key.transpose(2, 3)) / lengths + bias

        attn_weights = ACT2FN[self.config.attention_activation](qk).type_as(qk)

        # element-wise activations don't normalize like softmax, so masking is multiplicative (zeroing)
        if padding_mask is not None:
            attn_weights = attn_weights * padding_mask.unsqueeze(2)

        if causal_mask is not None:
            attn_weights = attn_weights * causal_mask

        return attn_weights

    def softmax_attention(self, query, key, padding_mask, causal_mask):
        "Standard softmax self-attention, as in the original Transformer paper"
        seq_len = key.size(2)
        # (sequence_length X sequence_length)
        bias = self.rel_pos_bias(seq_len)
        if seq_len != query.size(2):
            if query.size(2) != 1:
                raise ValueError("Size mismatch between Q and K in softmax attention")
            # incremental decoding: keep only the last row of the positional bias
            # (1 X sequence_length)
            bias = bias[-1:]

        # scaled attention
        query = query * self.scaling

        # (batch_size x number of chunks x chunk_size x chunk_size) if chunking
        # (batch_size x 1 x sequence_length x sequence_length) otherwise
        qk = torch.matmul(query, key.transpose(2, 3)) + bias

        # apply causal mask (presumed to be 1/0 for not masked / masked)
        # additive, but convert to 0/-inf (which is not explicitly in the Mega source code)
        if causal_mask is not None:
            additive_causal_mask = torch.zeros_like(causal_mask, dtype=qk.dtype)
            additive_causal_mask = additive_causal_mask.masked_fill((1 - causal_mask).bool(), float("-inf"))
            qk = qk + additive_causal_mask

        if padding_mask is not None:
            # 1 for tokens which are *not masked*
            # 0 for tokens which are *masked*
            # replace masked tokens with -inf to make softmax ignore them
            # need to invert the padding mask to match what mega original did
            padding_mask = 1 - padding_mask
            # if an entire row is masked, leave it unmasked to avoid softmax over all -inf (NaN)
            padding_mask_all = padding_mask.all(dim=-1, keepdim=True)
            padding_mask = torch.logical_and(padding_mask, ~padding_mask_all)
            qk = qk.masked_fill(padding_mask.unsqueeze(2).to(torch.bool), float("-inf"))

        attn_weights = self.softmax(qk).type_as(qk)
        return attn_weights

    def forward(
        self,
        input,
        padding_mask: Optional[torch.Tensor] = None,
        causal_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions=False,
        use_cache=False,
    ):
        """
        Mega's self-attention block, which combines multi-headed EMA with traditional self-attention

        Args:
            input (`torch.Tensor` of shape `(sequence_length, batch_size, hidden_size)`):
                Hidden states to be updated by Mega's self-attention
            padding_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked*
                or 0 for *masked*
            causal_mask (`torch.LongTensor` of shape `(sequence_length, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to causal attention, where elements are either 1 for *not
                masked* or 0 for *masked*
            past_key_values (`tuple(torch.Tensor)`, *optional*):
                The hidden states returned from the previous timestep during incremental decoding; expects that
                self-attention key, value, and EMA states are the first 3 entries in the tuple
            output_attentions (`bool`, default `False`):
                Whether to return self-attention weights
            use_cache (`bool`, default `False`):
                Whether to perform incremental decoding; uses `past_key_values` as prior state, and returns the updated
                states for use in the next step

        Returns:
            `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
            inputs:
            - **hidden_states** (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`) -- Hidden
              states from target sequence updated by Mega's self-attention
            - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape
              `(batch_size, 1, sequence_length, sequence_length)` -- The self-attention weights corresponding to how
              each token in the input sequence attends to every other token
            - **self_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              sequence_length, config.shared_representation_size)` -- The self-attention key state for use in the next
              step of incremental decoding
            - **self_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              sequence_length, config.hidden_size)` -- The self-attention value state for use in the next step of
              incremental decoding
            - **self_ema_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape
              `(batch_size, config.ndim)` The incremental EMA state for use in the next step of incremental decoding.
        """

        seq_len, bsz, embed_dim = input.size()
        if embed_dim != self.config.hidden_size:
            raise ValueError(f"Input embedding dimension should be {self.config.hidden_size}; received {embed_dim}")

        # store inputs for residual connection and handle pre-norm if requested
        residual = input
        if self.config.normalize_before_mega:
            input = self.norm(input)

        # (sequence_length X batch_size X hidden_size) -> (sequence_length X batch_size X intermediate_size)
        value = self.activation(self.v_proj(input))

        # unpack the incremental state if provided
        # assumed to be (self K, self V, self EMA state, cross K, cross V)
        # also assumes that incremental decoding is working one token at a time, so input sequence length must be 1
        if self.config.is_decoder and (past_key_values is not None):
            if seq_len > 1:
                raise ValueError(f"Incremental decoding only supports self sequence length of 1; received {seq_len}")
            # the first 3 items in the saved states will be these regardless of whether cross-attention is present
            prev_self_key, prev_self_value, prev_ema_state = past_key_values[0:3]
        else:
            prev_self_key = prev_self_value = prev_ema_state = None

        # ema output is (sequence_length x batch_size x hidden_size)
        # updated_ema_state will be None if use_cache=False; otherwise (batch_size, config.ndim)
        ema_out, updated_ema_state = self.ema_gate(
            input, attention_mask=padding_mask, prev_state=prev_ema_state, use_cache=use_cache
        )
        ema_out = self.dropout(ema_out)

        # (sequence_length X batch_size X hidden_size)
        # -> (sequence_length X batch_size X 2*hidden_size + config.shared_representation_size + config.intermediate_size)
        # - residual_weight -> sigmoid -> applied to residual connection in torch.addcmul
        # - query_key_gates -> split into two components: query_key becomes query and key for attention input, gates becomes gating for self-attention output
        # - intermediate_state -> added to weighted attention output, sent through activation, and has inputs subtracted during
        #   torch.addcmul to create the final layer output
        base = self.mx_proj(ema_out)
        residual_weight, query_key_gates, intermediate_state = torch.split(
            base,
            [
                self.config.hidden_size,
                self.config.shared_representation_size + self.config.intermediate_size,
                self.config.hidden_size,
            ],
            dim=-1,
        )

        # (sequence_length X batch_size X hidden_size)
        residual_weight = torch.sigmoid(residual_weight)

        # (sequence_length X batch_size X shared_representation_size + intermediate_size)
        query_key_gates = F.silu(query_key_gates)

        # split into two different tensors: one for Q/K usage and the other for gating self-attention
        query_key, attention_gate = torch.split(
            query_key_gates, [self.config.shared_representation_size, self.config.intermediate_size], dim=-1
        )

        # (sequence_length X batch_size X shared_representation_size)
        # -> (sequence_length X batch_size X 1 X shared_representation_size)
        # -> (sequence_length X batch_size X 2 X shared_representation_size)
        query_key = query_key.unsqueeze(2) * self.qk_weight + self.qk_bias

        # (sequence_length X batch_size X 2 X shared_representation_size)
        # -> 2 tensors of (sequence_length X batch_size X shared_representation_size)
        query, key = torch.unbind(query_key, dim=2)

        # (sequence_length X batch_size X dimension)
        # -> (batch_size X sequence_length X dimension)
        # where `dimension` is either shared_representation_size (queries and keys) or intermediate_size (values)
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

        if self.config.is_decoder:
            # combine history and current to save updated state (if history is provided)
            # when chunking is applied, the past states will be None at the end of the chunk, in
            # which case, proceed as if no K/V history had been provided
            # saved states are stored with shape (batch_size X sequence_length X dimension)
            if prev_self_key is not None:
                key = torch.cat([prev_self_key, key], dim=1)
            if prev_self_value is not None:
                value = torch.cat([prev_self_value, value], dim=1)

            # if not chunking, store as-is
            if not self.config.use_chunking:
                updated_self_key = key
                updated_self_value = value
            else:
                curr_len = key.size(1) % self.config.chunk_size
                if curr_len == 0:
                    # if we're chunking and have reached the end of a chunk, wipe out the saved state
                    updated_self_key = None
                    updated_self_value = None
                else:
                    updated_self_key = key
                    updated_self_value = value

        ctx_len = key.size(1)  # potentially differs from seq_len because of incremental decoding
        if not self.config.use_chunking:
            # if we're not chunking, treat the entire sequence as one long chunk
            # (batch_size X sequence_length X dimension) -> (batch_size X 1 X sequence_length X dimension)
            query = query.unsqueeze(1)
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)
            if padding_mask is not None:
                # (batch_size X sequence_length) -> (batch_size X 1 X sequence_length)
                padding_mask = padding_mask.unsqueeze(1)
        else:
            # otherwise, split the sequences in the batch into `n_chunks` chunks of size `chunk_size`
            if seq_len < self.config.chunk_size:
                query = query.unsqueeze(1)
            else:
                # (batch_size X sequence_length X dimension) -> (batch_size X n_chunks X chunk_size X dimension)
                n_chunks = seq_len // self.config.chunk_size
                query = query.reshape(bsz, n_chunks, self.config.chunk_size, self.config.shared_representation_size)

            if ctx_len < self.config.chunk_size:
                key = key.unsqueeze(1)
                value = value.unsqueeze(1)
                if padding_mask is not None:
                    padding_mask = padding_mask.unsqueeze(1)
            else:
                # (batch_size X sequence_length X dimension) -> (batch_size X n_chunks X chunk_size X dimension)
                n_chunks = ctx_len // self.config.chunk_size
                key = key.reshape(bsz, n_chunks, self.config.chunk_size, self.config.shared_representation_size)
                value = value.reshape(bsz, n_chunks, self.config.chunk_size, self.config.intermediate_size)
                if padding_mask is not None:
                    padding_mask = padding_mask.view(bsz, n_chunks, self.config.chunk_size)

        # this is in the original Mega implementation to work around fork/join parallelism not supporting optional types
        if padding_mask is not None and padding_mask.dim() == 0:
            padding_mask = None

        attn_weights = self.attention_function(query, key, padding_mask=padding_mask, causal_mask=causal_mask)

        value = self.hidden_dropout(value, batch_first=True)
        kernel = self.attention_dropout(attn_weights)

        # (batch_size x n_chunks x chunk_size x intermediate_size) -> (sequence_length X batch_size X intermediate_size)
        weighted_self_output = (
            torch.matmul(kernel, value).view(bsz, seq_len, self.config.intermediate_size).transpose(0, 1)
        )

        # (sequence_length X batch_size X intermediate_size) -> (sequence_length X batch_size X hidden_size)
        weighted_self_output = self.activation(intermediate_state + self.h_proj(weighted_self_output * attention_gate))
        weighted_self_output = self.dropout(weighted_self_output)
        # gated residual: out = residual + residual_weight * (weighted_self_output - residual)
        # (sequence_length X batch_size X hidden_size)
        out = torch.addcmul(residual, residual_weight, weighted_self_output - residual)

        if not self.config.normalize_before_mega:
            out = self.norm(out)

        return_values = (out, attn_weights) if output_attentions else (out,)

        if self.config.is_decoder:
            return_values = return_values + (updated_self_key, updated_self_value, updated_ema_state)

        return return_values
+
+
class MegaNormalizedFeedForwardNetwork(nn.Module):
    """
    Normalized feed-forward network used in Mega blocks. Left as-is from original Mega repo aside from retrieving args
    from Hugging Face config.

    Applies a two-layer MLP with a residual connection; normalization happens either before the MLP
    (`config.normalize_before_ffn=True`) or after the residual addition.
    """

    def __init__(self, config: MegaConfig):
        super().__init__()

        self.config = config
        self.hidden_dim = config.nffn_hidden_size
        self.act_fn = config.activation
        self.activation = ACT2FN[config.activation]

        # featurewise flag is shared between both dropout layers
        featurewise = self.config.use_feature_dropout
        self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=featurewise)
        self.hidden_dropout = MegaDropout(self.config.nffn_activation_dropout_prob, is_featurewise=featurewise)

        self.prenorm = self.config.normalize_before_ffn
        self.norm = MegaSequenceNorm(
            self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine
        )

        self.fc1 = nn.Linear(self.config.hidden_size, self.config.nffn_hidden_size)
        self.fc2 = nn.Linear(self.config.nffn_hidden_size, self.config.hidden_size)

    def forward(self, inputs):
        # keep the raw inputs around for the residual connection
        residual = inputs

        # pre-norm variant normalizes before the MLP
        hidden_states = self.norm(inputs) if self.prenorm else inputs

        # two-layer MLP with dropout after the activation and after the output projection
        hidden_states = self.hidden_dropout(self.activation(self.fc1(hidden_states)))
        hidden_states = self.dropout(self.fc2(hidden_states))
        hidden_states = hidden_states + residual

        # post-norm variant normalizes after the residual addition
        return hidden_states if self.prenorm else self.norm(hidden_states)
+
+
class MegaBlock(nn.Module):
    """
    A single Mega layer: moving-average gated self-attention, optional gated cross-attention
    (decoder-only), and an optional normalized feed-forward network.
    """

    def __init__(self, config: MegaConfig):
        super().__init__()
        self.seq_len_dim = 1
        self.mega_layer = MegaMovingAverageGatedAttention(config)
        # NFFN is optional per config
        self.nffn = MegaNormalizedFeedForwardNetwork(config) if config.use_normalized_ffn else None
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # cross-attention only makes sense for a decoder consuming encoder states
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.cross_attn = MegaGatedCrossAttention(config)
        else:
            self.cross_attn = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        causal_mask: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[torch.FloatTensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor]:
        """
        A single Mega layer: either encoder or decoder, with optional cross-attention and optional normalized
        feed-forward layer

        Args:
            hidden_states (`torch.Tensor` of shape `(target_sequence_length, batch_size, hidden_size)`):
                Hidden states to be updated by the Mega block
            attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
                Indicates which entries in the self/target sequence are to be ignored (mostly due to padding), where
                elements are either 1 for *not masked* or 0 for *masked*. Causal attention is enforced internally.
            causal_mask (`torch.LongTensor` of shape `(sequence_length, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to causal attention, where elements are either 1 for *not
                masked* or 0 for *masked*
            encoder_hidden_states (`torch.Tensor`, of shape `(source_sequence_length, batch_size, hidden_size)`, *optional*):
                Encoder hidden states to be used for cross-attention (and required for encoder-decoder model setup)
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*):
                Indicates which entries in the cross/source sequence are to be ignored (mostly due to padding), where
                elements are either 1 for *not masked* or 0 for *masked*.
            past_key_value (`tuple(torch.Tensor)`, *optional*):
                The hidden states returned from the previous timestep during incremental decoding; expects that
                self-attention key, value, and EMA states are the first 3 entries in the tuple, and (if doing
                cross-attention) cross-attention key and value are the last 2 entries in the tuple
            output_attentions (`bool`, default `False`):
                Whether to return self-attention weights
            use_cache (`bool`, default `False`):
                Whether to perform incremental decoding; uses `past_key_value` as prior state, and returns the updated
                states for use in the next step

        Returns:
            `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and
            inputs:
            - **hidden_states** (`torch.FloatTensor` of shape `(target_sequence_length, batch_size, hidden_size)`) --
              Hidden states from target sequence updated by Mega
            - **self_attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape
              `(batch_size, 1, target_sequence_length, target_sequence_length)` -- The self-attention weights
              corresponding to how each token in the input sequence attends to every other token
            - **cross_attn_weights** (*optional*, returned when `output_attentions=True` and
              `config.add_cross_attention=True`) `torch.FloatTensor` of shape `(batch_size, source_sequence_length,
              target_sequence_length)` -- Pairwise cross-attention weights between every entry in the source sequence
              and target sequence
            - **self_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              sequence_length, config.shared_representation_size)` -- The self-attention key state for use in the next
              step of incremental decoding
            - **self_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size,
              sequence_length, config.hidden_size)` -- The self-attention value state for use in the next step of
              incremental decoding
            - **self_ema_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape
              `(batch_size, config.ndim)` The incremental EMA state for use in the next step of incremental decoding.
            - **cross_key** (*optional*, returned when `use_cache=True` and `config.is_decoder=True`)
              `torch.FloatTensor` of shape `(batch_size, source_sequence_length, config.shared_representation_size)` --
              The cross-attention key state for use in the next step of incremental decoding
            - **cross_value** (*optional*, returned when `use_cache=True` and `config.is_decoder=True`)
              `torch.FloatTensor` of shape `(batch_size, source_sequence_length, config.hidden_size)` -- The
              cross-attention value state for use in the next step of incremental decoding
        """

        # incremental decoding in the MegaMultiDimensionDampedEma module requires that the attention mask has the same
        # sequence length as the input tensor; if we're caching incremental states, we assume the input
        # sequence length is 1 (Mega will break otherwise), so we take the padding mask for the final
        # token in the input (mask is received as [batch X sequence length])
        if use_cache and (past_key_value is not None) and (attention_mask is not None):
            mega_padding_mask = attention_mask[:, -1].unsqueeze(-1)
        else:
            mega_padding_mask = attention_mask

        mega_outputs = self.mega_layer(
            input=hidden_states,
            padding_mask=mega_padding_mask,
            causal_mask=causal_mask,
            past_key_values=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        # unpack by position: output is always first; cached states (if any) are always the last 3;
        # attention weights (if requested) sit at index 1
        new_hidden_states = mega_outputs[0]
        self_key, self_value, self_ema_state = mega_outputs[-3:] if use_cache else (None, None, None)
        self_attention_weights = mega_outputs[1] if output_attentions else None

        # optional cross attention
        if self.cross_attn is not None:
            if encoder_hidden_states is None:
                raise ValueError("Requested cross-attention without providing encoder hidden states")

            cross_attn_outputs = self.cross_attn(
                query=new_hidden_states,
                key=encoder_hidden_states,
                value=encoder_hidden_states,
                key_padding_mask=encoder_attention_mask,
                past_key_values=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            # update the hidden state from cross attention
            new_hidden_states = cross_attn_outputs[0]
            # store cross-attention k/v if caching
            cross_key, cross_value = cross_attn_outputs[-2:] if use_cache else (None, None)
            cross_attention_weights = cross_attn_outputs[1] if output_attentions else None

        # optional NFFN follows cross attention
        if self.nffn is not None:
            new_hidden_states = self.nffn(new_hidden_states)

        # assemble outputs in the documented order: hidden states, attention weights, cached states
        outs = (new_hidden_states,)
        if output_attentions:
            outs = outs + (self_attention_weights,)
            if self.cross_attn is not None:
                outs = outs + (cross_attention_weights,)

        if use_cache:
            new_key_values = (
                self_key,
                self_value,
                self_ema_state,
            )
            if self.cross_attn is not None:
                new_key_values = new_key_values + (cross_key, cross_value)

            outs = outs + (new_key_values,)

        return outs
+
+
# copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->Mega
class MegaPooler(nn.Module):
    """Pool a sequence by passing the first token's hidden state through a tanh-activated dense layer."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pooling strategy: keep only the hidden state of the first token,
        # then apply the learned dense projection and tanh non-linearity.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
+
+
class MegaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MegaConfig
    base_model_prefix = "mega"
    supports_gradient_checkpointing = False
    _no_split_modules = ["MegaMovingAverageGatedAttention"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, MegaMultiDimensionDampedEma):
            with torch.no_grad():
                # delta & alpha
                nn.init.normal_(module.damping_factor, mean=0.0, std=self.config.ema_delta_alpha_range)
                nn.init.normal_(module.decay_factor, mean=0.0, std=self.config.ema_delta_alpha_range)
                # beta [1, -1, 1, -1, ...] seems more stable.
                val = torch.ones(self.config.ema_projection_size, 1)
                if self.config.ema_projection_size > 1:
                    # flip the sign of every second entry to build the alternating pattern
                    idx = torch.tensor(list(range(1, self.config.ema_projection_size, 2)))
                    val.index_fill_(0, idx, -1.0)
                module.ema_expansion_matrix.normal_(mean=0.0, std=self.config.ema_beta_range).add_(val)
                # gamma & omega
                nn.init.normal_(module.kernel_projection_matrix, mean=0.0, std=self.config.ema_gamma_omega_range)
                nn.init.normal_(module.residual_weight, mean=0.0, std=self.config.ema_gamma_omega_range)
        elif isinstance(module, MegaSimpleRelativePositionalBias):
            nn.init.normal_(module.rel_pos_bias, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, MegaRotaryRelativePositionalBias):
            nn.init.normal_(module.alpha, mean=0.0, std=self.config.initializer_range)
            nn.init.normal_(module.b_param, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, MegaScaleNorm):
            # scale/RMS norms only carry parameters when affine is enabled
            if self.config.norm_affine:
                nn.init.constant_(module.scalar, 1.0)
        elif isinstance(module, MegaRMSNorm):
            if self.config.norm_affine:
                nn.init.constant_(module.weight, 1.0)
        elif isinstance(module, MegaMovingAverageGatedAttention):
            # linear layers covered separately by the generic nn.Linear init below
            nn.init.normal_(module.qk_weight, mean=0.0, std=self.config.initializer_range)
            nn.init.constant_(module.qk_bias, 0.0)
        elif isinstance(module, nn.Linear):
            # initializes all linear layers in the entire network
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # keep the padding embedding at zero
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
# Shared docstring fragment prepended to model classes via `add_start_docstrings` below.
MEGA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MegaConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Template for forward-pass argument docs; `{0}` is filled with the expected input shape
# by `add_start_docstrings_to_model_forward`.
MEGA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            This parameter can only be used when the model is initialized with `add_token_type_embeddings` parameter
            set to `True`. All the value in this tensor should be always < config.type_vocab_size.

            [What are token type IDs?](../glossary#token-type-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare MEGA Model transformer outputting raw hidden-states without any specific head on top.",
    MEGA_START_DOCSTRING,
)
class MegaModel(MegaPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added after self-attention, following the architecture described in *Mega: Moving Average
    Equipped Gated Attention*_ by Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig,
    Jonathan May, and Luke Zettlemoyer

    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to
    `True` and `bidirectional` set to `False`. To be used in a Seq2Seq model, the model needs to initialized with both
    `is_decoder=True` and `bidirectional=False` argument as well as `add_cross_attention` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.

    .. _*Mega: Moving Average Equipped Gated Attention*: https://arxiv.org/abs/2209.10655

    """

    def __init__(self, config: MegaConfig, add_pooling_layer: bool = True):
        """Build the embedding layer, the stack of Mega blocks, and an optional pooler.

        Args:
            config: model configuration; `num_hidden_layers` controls the block count.
            add_pooling_layer: when `False` (used by task heads that pool themselves),
                no `MegaPooler` is created and `pooler_output` will be `None`.
        """
        super().__init__(config)
        self.config = config

        self.embedding_layer = MegaEmbeddings(config)
        self.layers = nn.ModuleList([MegaBlock(config) for _ in range(config.num_hidden_layers)])

        self.pooler = MegaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing (retained from RoBERTa code)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding table used by the embedding layer."""
        return self.embedding_layer.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding table (used e.g. when resizing vocab)."""
        self.embedding_layer.word_embeddings = value

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """

        # Resolve output flags from the config when not explicitly passed.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.config.use_chunking:
            # NOTE(review): this unconditionally replaces the true sequence length with
            # config.chunk_size, so `sequence_length` below equals chunk_size whenever
            # chunking is enabled — which makes the divisibility check right after
            # unreachable. Confirm against the reference implementation before relying
            # on that validation.
            input_shape = torch.tensor([input_shape[0], self.config.chunk_size])

        batch_size, sequence_length = input_shape

        if self.config.use_chunking and (sequence_length > self.config.chunk_size):
            if sequence_length % self.config.chunk_size != 0:
                raise ValueError(
                    f"config.use_chunking is activated; input sequence length must be shorter than or a multiple of config.chunk_size\nreceived sequence length of {sequence_length} with chunk size {self.config.chunk_size}"
                )

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache

            # Mega expects the causal mask to be a 2D square matrix of (from) x (to) over the input sequence length
            # the HF utility function generates a 3D causal mask which includes batch size, so we'll create a dummy
            # mask with the correct device and all ones
            temp_mask_for_extension = torch.ones((1, sequence_length), dtype=torch.long, device=device)
            causal_mask = self.create_extended_attention_mask_for_decoder(input_shape, temp_mask_for_extension)

            # get rid of batch dimension in the generated mask; result is (sequence_length X sequence_length)
            causal_mask = causal_mask.squeeze(0)
        else:
            # Encoder mode: no causal masking and no incremental decoding cache.
            use_cache = False
            causal_mask = None

        # if using cache, make sure we have a tuple of tuples which matches the length of our hidden layers
        if (past_key_values is not None) and (len(past_key_values) != self.config.num_hidden_layers):
            raise ValueError(
                f"Received past key/value cache with size mismatch; expected {self.config.num_hidden_layers}, received {len(past_key_values)}"
            )

        # get embeddings (batch X sequence length X embed dim)
        embedding_output = self.embedding_layer(
            input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )

        # transpose for Mega --> (seq len X batch X embed dim)
        hidden_states = embedding_output.transpose(0, 1)

        # we expect encoder hidden states to also have batch first in line
        # with typical Hugging Face behavior (which is also how we return them)
        # Mega expects sequence length first, so do the same transpose here
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.transpose(0, 1)

        # pass through mega layers
        # NOTE: the first collected hidden state is the (batch-first) embedding output,
        # while subsequent entries are transposed per layer below.
        all_hidden_states = (embedding_output,) if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        next_decoder_cache = () if use_cache else None
        for i, mega_layer in enumerate(self.layers):
            current_decoder_cache = past_key_values[i] if past_key_values is not None else None
            mega_outputs = mega_layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                causal_mask=causal_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=current_decoder_cache,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = mega_outputs[0]
            if output_hidden_states:
                # store layer-wise hidden states in the way that the user expects
                # (seq len X batch X embed dim) --> (batch X seq len X embed dim)
                all_hidden_states += (hidden_states.transpose(0, 1),)
            if output_attentions:
                self_attn_weights = mega_outputs[1]
                all_self_attentions += (self_attn_weights,)
                if self.config.add_cross_attention:
                    cross_attn_weights = mega_outputs[2]
                    all_cross_attentions += (cross_attn_weights,)
            if use_cache:
                # the layer's updated key/value cache is returned as its last output
                updated_cache = mega_outputs[-1]
                next_decoder_cache += (updated_cache,)

        # transpose final hidden states
        hidden_states = hidden_states.transpose(0, 1)

        # optional pooling layer
        pooled_output = self.pooler(hidden_states) if self.pooler is not None else None

        if not return_dict:
            # NOTE(review): unlike the usual HF tuple convention, None entries are NOT
            # filtered out here — optional outputs are returned positionally even when None.
            return (hidden_states, pooled_output) + (
                all_hidden_states,
                next_decoder_cache,
                all_self_attentions,
                all_cross_attentions,
            )

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled_output,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
@add_start_docstrings(
    """MEGA Model with a `language modeling` head on top for CLM fine-tuning.""", MEGA_START_DOCSTRING
)
class MegaForCausalLM(MegaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MegaConfig):
        """Mega encoder (no pooler) plus an optional dense+tanh bridge and an LM head."""
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `MegaForCausalLM` as a standalone, add `is_decoder=True.`")

        self.mega = MegaModel(config, add_pooling_layer=False)

        # Optional intermediate projection between the encoder output and the LM head.
        if config.add_lm_hidden_dense_layer:
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            self.hidden_activation = nn.Tanh()
        else:
            self.dense = None
            self.hidden_activation = None

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head (its weight is tied to the input embeddings)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MegaForCausalLM, AutoConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("mnaylor/mega-base-wikitext")
        >>> config = AutoConfig.from_pretrained("mnaylor/mega-base-wikitext")
        >>> config.is_decoder = True
        >>> config.bidirectional = False
        >>> model = MegaForCausalLM.from_pretrained(
        ...     "mnaylor/mega-base-wikitext", config=config, ignore_mismatched_sizes=True
        ... )

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Caching is disabled whenever labels are provided (full-sequence loss computation).
        if labels is not None:
            use_cache = False

        outputs = self.mega(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Optional dense+tanh bridge before projecting to vocabulary logits.
        if self.dense is not None:
            sequence_output = self.dense(sequence_output)
            sequence_output = self.hidden_activation(sequence_output)

        prediction_scores = self.lm_head(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
        """Assemble per-step generation inputs: default all-ones attention mask, and
        keep only the last token of `input_ids` once a cache is present."""
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every layer's cached tensors along dim 0 to follow beam search.

        NOTE(review): assumes dim 0 of each cached tensor is the batch/beam
        dimension — confirm against MegaBlock's cache layout.
        """
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
@add_start_docstrings("""MEGA Model with a `language modeling` head on top.""", MEGA_START_DOCSTRING)
class MegaForMaskedLM(MegaPreTrainedModel):
    _tied_weights_keys = ["mlm_head.weight"]

    def __init__(self, config: MegaConfig):
        """Mega encoder (no pooler) with an optional dense+tanh bridge and an MLM head."""
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `MegaForMaskedLM`, set `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.mega = MegaModel(config, add_pooling_layer=False)
        wants_bridge = config.add_lm_hidden_dense_layer
        self.dense = nn.Linear(config.hidden_size, config.hidden_size) if wants_bridge else None
        self.hidden_activation = nn.Tanh() if wants_bridge else None
        self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout_prob)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the MLM projection head."""
        return self.mlm_head

    def set_output_embeddings(self, new_embeddings):
        """Swap in a new MLM projection head."""
        self.mlm_head = new_embeddings

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="",
        expected_output="' Paris'",
        expected_loss=0.1,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, optional, defaults to *{}*):
            Used to hide legacy arguments that have been deprecated.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        mega_outputs = self.mega(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden = mega_outputs[0]
        # Optional dense + tanh transform before projecting onto the vocabulary.
        if self.dense is not None:
            hidden = self.hidden_activation(self.dense(hidden))
        prediction_scores = self.mlm_head(hidden)

        masked_lm_loss = None
        if labels is not None:
            masked_lm_loss = CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            tuple_output = (prediction_scores,) + mega_outputs[2:]
            return tuple_output if masked_lm_loss is None else (masked_lm_loss,) + tuple_output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=mega_outputs.hidden_states,
            attentions=mega_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    MEGA Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    MEGA_START_DOCSTRING,
)
class MegaForSequenceClassification(MegaPreTrainedModel):
    def __init__(self, config):
        """Mega encoder (no pooler) followed by a tanh-MLP classification head."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.mega = MegaModel(config, add_pooling_layer=False)
        self.classifier = MegaClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.mega(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier(encoder_outputs[0])

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            tuple_output = (logits,) + encoder_outputs[2:]
            return tuple_output if loss is None else (loss,) + tuple_output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    MEGA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    MEGA_START_DOCSTRING,
)
class MegaForMultipleChoice(MegaPreTrainedModel):
    def __init__(self, config):
        """Pooled Mega encoder with dropout and a single-logit-per-choice classifier."""
        super().__init__(config)

        self.mega = MegaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = reference.shape[1]

        # Collapse (batch, num_choices, ...) -> (batch * num_choices, ...) before encoding.
        flat_input_ids = None if input_ids is None else input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = None if token_type_ids is None else token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = None if attention_mask is None else attention_mask.view(-1, attention_mask.size(-1))
        if inputs_embeds is None:
            flat_inputs_embeds = None
        else:
            flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.mega(
            flat_input_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        pooled = self.dropout(encoder_outputs[1])
        # One logit per choice, regrouped back to (batch, num_choices).
        reshaped_logits = self.classifier(pooled).view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if not return_dict:
            tuple_output = (reshaped_logits,) + encoder_outputs[2:]
            return tuple_output if loss is None else (loss,) + tuple_output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    MEGA Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    MEGA_START_DOCSTRING,
)
class MegaForTokenClassification(MegaPreTrainedModel):
    def __init__(self, config):
        """Mega encoder (no pooler) followed by dropout and a per-token linear classifier."""
        super().__init__(config)
        self.num_labels = config.num_labels

        self.mega = MegaModel(config, add_pooling_layer=False)
        # Fall back to the generic hidden dropout rate when no classifier-specific rate is set.
        drop_rate = config.classifier_dropout
        if drop_rate is None:
            drop_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.mega(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            tuple_output = (logits,) + encoder_outputs[2:]
            return tuple_output if loss is None else (loss,) + tuple_output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
# copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Mega
class MegaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        """dropout -> dense -> tanh -> dropout -> projection onto `num_labels` logits."""
        super().__init__()
        hidden = config.hidden_size
        # Classifier-specific dropout if configured, otherwise the generic hidden dropout.
        drop_p = config.classifier_dropout
        if drop_p is None:
            drop_p = config.hidden_dropout_prob
        self.dense = nn.Linear(hidden, hidden)
        self.dropout = nn.Dropout(drop_p)
        self.out_proj = nn.Linear(hidden, config.num_labels)

    def forward(self, features, **kwargs):
        """Pool the first token of `features` and map it to per-label logits."""
        pooled = self.dropout(features[:, 0, :])  # first token (equiv. to [CLS])
        pooled = torch.tanh(self.dense(pooled))
        return self.out_proj(self.dropout(pooled))
+
+
@add_start_docstrings(
    """
    MEGA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MEGA_START_DOCSTRING,
)
class MegaForQuestionAnswering(MegaPreTrainedModel):
    def __init__(self, config):
        """Mega encoder (no pooler) with a 2-way linear head for span start/end logits."""
        super().__init__(config)
        self.num_labels = config.num_labels

        self.mega = MegaModel(config, add_pooling_layer=False)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.mega(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Split the 2-channel projection into per-token start and end logits.
        span_logits = self.qa_outputs(encoder_outputs[0])
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets clamped to this index are ignored by the loss below.
            ignored_index = start_logits.size(1)

            def _prep(positions):
                # Multi-GPU gathering can add a trailing dim; drop it, then clamp
                # out-of-sequence targets onto the ignored index.
                if positions.dim() > 1:
                    positions = positions.squeeze(-1)
                return positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, _prep(start_positions))
            end_loss = loss_fct(end_logits, _prep(end_positions))
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            tuple_output = (start_logits, end_logits) + encoder_outputs[2:]
            return tuple_output if total_loss is None else (total_loss,) + tuple_output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
# Explicit public API of this module, consumed by `from ... import *` and the
# package's lazy-import machinery.
__all__ = [
    "MegaForCausalLM",
    "MegaForMaskedLM",
    "MegaForMultipleChoice",
    "MegaForQuestionAnswering",
    "MegaForSequenceClassification",
    "MegaForTokenClassification",
    "MegaModel",
    "MegaPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..03b556e2eddf6d5f81f6f4a15346596dd5a32f85
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers see the real submodules so re-exported symbols resolve.
+    from .configuration_mmbt import *
+    from .modeling_mmbt import *
+else:
+    import sys
+
+    # At runtime, replace this package module with a lazy proxy so heavy
+    # submodules (e.g. the torch-dependent modeling code) are imported on
+    # first attribute access instead of at package-import time.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c58e4e6cdd0d0a8c87b9b94ca896c87970b5654
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py
@@ -0,0 +1,45 @@
+# coding=utf-8
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MMBT configuration"""
+
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MMBTConfig:
+    """
+    This is the configuration class to store the configuration of a [`MMBTModel`]. It is used to instantiate a MMBT
+    model according to the specified arguments, defining the model architecture.
+
+    Args:
+        config ([`PreTrainedConfig`]):
+            Config of the underlying Transformer models. Its values are copied over to use a single config.
+        num_labels (`int`, *optional*):
+            Size of final Linear layer for classification.
+        modal_hidden_size (`int`, *optional*, defaults to 2048):
+            Embedding dimension of the non-text modality encoder.
+    """
+
+    def __init__(self, config, num_labels=None, modal_hidden_size=2048):
+        # NOTE(review): this binds the *same* dict object as `config.__dict__`,
+        # so the attribute assignments below (and any later mutation of this
+        # config) are visible on the wrapped `config` as well — the values are
+        # shared, not copied. Confirm this aliasing is intended.
+        self.__dict__ = config.__dict__
+        self.modal_hidden_size = modal_hidden_size
+        # `if num_labels:` also skips num_labels == 0, not only None;
+        # presumably a zero-label head is never meaningful here.
+        if num_labels:
+            self.num_labels = num_labels
+
+
+__all__ = ["MMBTConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py
new file mode 100644
index 0000000000000000000000000000000000000000..45ae577f7fced2d276f4f54c5bf859e27e08ebae
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py
@@ -0,0 +1,410 @@
+# coding=utf-8
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MMBT model."""
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+
+from ....modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
+from ....modeling_utils import ModuleUtilsMixin
+from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "MMBTConfig"
+
+
+class ModalEmbeddings(nn.Module):
+    """Generic Modal Embeddings which takes in an encoder, and a transformer embedding."""
+
+    def __init__(self, config, encoder, embeddings):
+        super().__init__()
+        self.config = config
+        # User-supplied encoder mapping the raw modal input to a sequence of
+        # feature vectors of size `config.modal_hidden_size`.
+        self.encoder = encoder
+        # Projects modal features into the text transformer's hidden space.
+        self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size)
+        # The remaining layers are *shared* with the text transformer's
+        # embedding module so both modalities live in one embedding space.
+        self.position_embeddings = embeddings.position_embeddings
+        self.token_type_embeddings = embeddings.token_type_embeddings
+        self.word_embeddings = embeddings.word_embeddings
+        self.LayerNorm = embeddings.LayerNorm
+        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
+
+    def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None):
+        """Embed the non-text modality, optionally framed by start/end word tokens.
+
+        Returns the BERT-style sum of token + position + token-type embeddings,
+        LayerNorm'ed and dropout'ed; shape is `(batch, seq, hidden_size)` where
+        `seq` is the encoder output length plus one per framing token.
+        """
+        # Encode the modal input and project into the transformer hidden size.
+        token_embeddings = self.proj_embeddings(self.encoder(input_modal))
+        seq_length = token_embeddings.size(1)
+
+        if start_token is not None:
+            # Prepend e.g. a [CLS] word embedding; the sequence grows by one.
+            start_token_embeds = self.word_embeddings(start_token)
+            seq_length += 1
+            token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1)
+
+        if end_token is not None:
+            # Append e.g. a [SEP] word embedding; the sequence grows by one.
+            end_token_embeds = self.word_embeddings(end_token)
+            seq_length += 1
+            token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1)
+
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_modal.device)
+            position_ids = position_ids.unsqueeze(0).expand(input_modal.size(0), seq_length)
+
+        if token_type_ids is None:
+            # Modal tokens default to segment 0 (the text side defaults to
+            # segment 1 in MMBTModel.forward).
+            token_type_ids = torch.zeros(
+                (input_modal.size(0), seq_length), dtype=torch.long, device=input_modal.device
+            )
+
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = token_embeddings + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+MMBT_START_DOCSTRING = r"""
+ MMBT model was proposed in [Supervised Multimodal Bitransformers for Classifying Images and
+ Text](https://github.com/facebookresearch/mmbt) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine.
+    It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and
+    obtains state-of-the-art performance on various multimodal classification benchmark tasks.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`MMBTConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration.
+ transformer (`nn.Module`): A text transformer that is used by MMBT.
+ It should have embeddings, encoder, and pooler attributes.
+ encoder (`nn.Module`): Encoder for the second modality.
+ It should take in a batch of modal inputs and return k, n dimension embeddings.
+"""
+
+MMBT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_modal (`torch.FloatTensor` of shape `(batch_size, ***)`):
+ The other modality data. It will be the shape that the encoder for that type expects. e.g. With an Image
+ Encoder, the shape would be (batch_size, channels, height, width)
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. It does not expect [CLS] token to be added as it's
+ appended to the end of other modality embeddings. Indices can be obtained using [`AutoTokenizer`]. See
+ [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ modal_start_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for classification
+ tasks.
+ modal_end_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used.
+ attention_mask (*optional*) `torch.FloatTensor` of shape `(batch_size, sequence_length)`:
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, sequence_length)`:
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+ 1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ modal_token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, modal_sequence_length)`:
+ Segment token indices to indicate different portions of the non-text modality. The embeddings from these
+ tokens will be summed with the respective token embeddings for the non-text modality.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ modal_position_ids (`torch.LongTensor` of shape `(batch_size, modal_sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings for the non-text modality.
+ Selected in the range `[0, config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare MMBT Model outputting raw hidden-states without any specific head on top.",
+ MMBT_START_DOCSTRING,
+)
+class MMBTModel(nn.Module, ModuleUtilsMixin):
+ def __init__(self, config, transformer, encoder):
+ super().__init__()
+ self.config = config
+ self.transformer = transformer
+ self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings)
+
+ @add_start_docstrings_to_model_forward(MMBT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_modal,
+ input_ids=None,
+ modal_start_tokens=None,
+ modal_end_tokens=None,
+ attention_mask=None,
+ token_type_ids=None,
+ modal_token_type_ids=None,
+ position_ids=None,
+ modal_position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ # For example purposes. Not runnable.
+ transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
+ encoder = ImageEncoder(args)
+ mmbt = MMBTModel(config, transformer, encoder)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_txt_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_txt_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ modal_embeddings = self.modal_encoder(
+ input_modal,
+ start_token=modal_start_tokens,
+ end_token=modal_end_tokens,
+ position_ids=modal_position_ids,
+ token_type_ids=modal_token_type_ids,
+ )
+
+ input_modal_shape = modal_embeddings.size()[:-1]
+
+ if token_type_ids is None:
+ token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device)
+
+ txt_embeddings = self.transformer.embeddings(
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
+ )
+
+ embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1)
+
+ input_shape = embedding_output.size()[:-1]
+
+ if attention_mask is None:
+ attention_mask = torch.ones(input_shape, device=device)
+ else:
+ attention_mask = torch.cat(
+ [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1
+ )
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(input_shape, device=device)
+ else:
+ encoder_attention_mask = torch.cat(
+ [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1
+ )
+
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.transformer.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.transformer.pooler(sequence_output)
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+
+@add_start_docstrings(
+ """
+ MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
+ """,
+ MMBT_START_DOCSTRING,
+ MMBT_INPUTS_DOCSTRING,
+)
+class MMBTForClassification(nn.Module):
+ r"""
+ **labels**: (*optional*) `torch.LongTensor` of shape `(batch_size,)`:
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+ Returns: *Tuple* comprising various elements depending on the configuration (config) and inputs: **loss**:
+ (*optional*, returned when `labels` is provided) `torch.FloatTensor` of shape `(1,)`: Classification (or
+ regression if config.num_labels==1) loss. **logits**:
+ `torch.FloatTensor` of shape `(batch_size, config.num_labels)` Classification (or regression if
+ config.num_labels==1) scores (before SoftMax).
+ **hidden_states**: (*optional*, returned when `output_hidden_states=True`) list of `torch.FloatTensor` (one for
+ the output of each layer + the output of the embeddings) of shape `(batch_size, sequence_length, hidden_size)`:
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**:
+ (*optional*, returned when `output_attentions=True`) list of `torch.FloatTensor` (one for each layer) of shape
+ `(batch_size, num_heads, sequence_length, sequence_length)`: Attentions weights after the attention softmax, used
+ to compute the weighted average in the self-attention heads.
+
+ Examples:
+
+ ```python
+ # For example purposes. Not runnable.
+ transformer = BertModel.from_pretrained("google-bert/bert-base-uncased")
+ encoder = ImageEncoder(args)
+ model = MMBTForClassification(config, transformer, encoder)
+ outputs = model(input_modal, input_ids, labels=labels)
+ loss, logits = outputs[:2]
+ ```"""
+
+ def __init__(self, config, transformer, encoder):
+ super().__init__()
+ self.num_labels = config.num_labels
+
+ self.mmbt = MMBTModel(config, transformer, encoder)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(
+ self,
+ input_modal,
+ input_ids=None,
+ modal_start_tokens=None,
+ modal_end_tokens=None,
+ attention_mask=None,
+ token_type_ids=None,
+ modal_token_type_ids=None,
+ position_ids=None,
+ modal_position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.mmbt(
+ input_modal=input_modal,
+ input_ids=input_ids,
+ modal_start_tokens=modal_start_tokens,
+ modal_end_tokens=modal_end_tokens,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ modal_token_type_ids=modal_token_type_ids,
+ position_ids=position_ids,
+ modal_position_ids=modal_position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # We are doing regression
+ loss_fct = MSELoss()
+ loss = loss_fct(logits.view(-1), labels.view(-1))
+ else:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5373969ce7831491b0d5fa5495078fb1d3f6e4e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers see the real submodules so re-exported symbols resolve.
+    from .configuration_nat import *
+    from .modeling_nat import *
+else:
+    import sys
+
+    # At runtime, replace this package module with a lazy proxy so heavy
+    # submodules (e.g. the torch-dependent modeling code) are imported on
+    # first attribute access instead of at package-import time.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py
new file mode 100644
index 0000000000000000000000000000000000000000..85961aa2fe8d0bc547b4919855ff8a28815a0fcd
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py
@@ -0,0 +1,148 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Neighborhood Attention Transformer model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+from ....utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class NatConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model
+    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the Nat
+    [shi-labs/nat-mini-in1k-224](https://huggingface.co/shi-labs/nat-mini-in1k-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        patch_size (`int`, *optional*, defaults to 4):
+            The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embed_dim (`int`, *optional*, defaults to 64):
+            Dimensionality of patch embedding.
+        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
+            Number of layers in each level of the encoder.
+        num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
+            Number of attention heads in each layer of the Transformer encoder.
+        kernel_size (`int`, *optional*, defaults to 7):
+            Neighborhood Attention kernel size.
+        mlp_ratio (`float`, *optional*, defaults to 3.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not a learnable bias should be added to the queries, keys and values.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
+            The initial value for the layer scale. Disabled if <=0.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+
+    ```python
+    >>> from transformers import NatConfig, NatModel
+
+    >>> # Initializing a Nat shi-labs/nat-mini-in1k-224 style configuration
+    >>> configuration = NatConfig()
+
+    >>> # Initializing a model (with random weights) from the shi-labs/nat-mini-in1k-224 style configuration
+    >>> model = NatModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "nat"
+
+    # Map the library-generic attribute names onto Nat's stage-wise fields so
+    # code written against the common config interface keeps working.
+    attribute_map = {
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+    }
+
+    def __init__(
+        self,
+        patch_size=4,
+        num_channels=3,
+        embed_dim=64,
+        # NOTE(review): mutable list defaults are shared across all calls;
+        # safe only while no instance mutates `depths`/`num_heads` in place.
+        depths=[3, 4, 6, 5],
+        num_heads=[2, 4, 8, 16],
+        kernel_size=7,
+        mlp_ratio=3.0,
+        qkv_bias=True,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        drop_path_rate=0.1,
+        hidden_act="gelu",
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        layer_scale_init_value=0.0,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.embed_dim = embed_dim
+        self.depths = depths
+        # One encoder level per entry in `depths`.
+        self.num_layers = len(depths)
+        self.num_heads = num_heads
+        self.kernel_size = kernel_size
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.drop_path_rate = drop_path_rate
+        self.hidden_act = hidden_act
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        # we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel
+        # this indicates the channel dimension after the last stage of the model
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+        self.layer_scale_init_value = layer_scale_init_value
+        # "stem" plus one named stage per encoder level; used by the backbone API.
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+__all__ = ["NatConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py
new file mode 100644
index 0000000000000000000000000000000000000000..70ecffcf51ea9afa07baf52b32181d9a803b9871
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py
@@ -0,0 +1,953 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Neighborhood Attention Transformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import BackboneOutput
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import (
+ ModelOutput,
+ OptionalDependencyNotAvailable,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_natten_available,
+ logging,
+ replace_return_docstrings,
+ requires_backends,
+)
+from ....utils.backbone_utils import BackboneMixin
+from .configuration_nat import NatConfig
+
+
if is_natten_available():
    from natten.functional import natten2dav, natten2dqkrpb
else:
    # Stub implementations so this module still imports when the optional
    # `natten` package is missing; actually calling the kernels raises.
    # Model classes additionally guard with `requires_backends(self, ["natten"])`.

    def natten2dqkrpb(*args, **kwargs):
        raise OptionalDependencyNotAvailable()

    def natten2dav(*args, **kwargs):
        raise OptionalDependencyNotAvailable()


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "NatConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "shi-labs/nat-mini-in1k-224"
_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "shi-labs/nat-mini-in1k-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
+
+
+# drop_path and NatDropPath are from the timm library.
+
+
@dataclass
class NatEncoderOutput(ModelOutput):
    """
    Output type of [`NatEncoder`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the final stage of the encoder.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor for the embedding output plus one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One attention-weight tensor per stage, of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`, taken after the attention softmax.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            The same hidden states as above but reshaped to `(batch_size, hidden_size,
            height, width)` so the spatial layout is explicit.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class NatModelOutput(ModelOutput):
    """
    Output type of [`NatModel`], adding a pooled representation on top of the encoder outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the final stage of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Spatial average pooling of the last hidden state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor for the embedding output plus one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One attention-weight tensor per stage, of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`, taken after the attention softmax.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            The same hidden states as above but reshaped to `(batch_size, hidden_size,
            height, width)` so the spatial layout is explicit.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class NatImageClassifierOutput(ModelOutput):
    """
    Output type of [`NatForImageClassification`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores, before any softmax.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor for the embedding output plus one per stage, each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            One attention-weight tensor per stage, of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`, taken after the attention softmax.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            The same hidden states as above but reshaped to `(batch_size, hidden_size,
            height, width)` so the spatial layout is explicit.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
class NatEmbeddings(nn.Module):
    """
    Turn pixel values into patch embeddings, then apply layer norm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.patch_embeddings = NatPatchEmbeddings(config)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
        # Embed -> normalize -> regularize, in one pass.
        return self.dropout(self.norm(self.patch_embeddings(pixel_values)))
+
+
class NatPatchEmbeddings(nn.Module):
    """
    Convert `pixel_values` of shape `(batch_size, num_channels, height, width)` into channels-last
    patch embeddings of shape `(batch_size, height/4, width/4, hidden_size)` using two strided
    3x3 convolutions (an overlapping convolutional stem with overall stride 4).
    """

    def __init__(self, config):
        super().__init__()
        if config.patch_size != 4:
            # TODO: Support arbitrary patch sizes.
            raise ValueError("Dinat only supports patch size of 4 at the moment.")

        self.num_channels = config.num_channels
        hidden_size = config.embed_dim
        # Two stride-2 convs give the total stride of 4 (= patch size).
        self.projection = nn.Sequential(
            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # NCHW -> NHWC: downstream NAT modules operate channels-last.
        return self.projection(pixel_values).permute(0, 2, 3, 1)
+
+
class NatDownsampler(nn.Module):
    """
    Convolutional downsampling between stages: halves each spatial dimension and doubles
    the channel count, followed by a normalization layer.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class applied after the reduction.
    """

    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.dim = dim
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
        # Conv2d expects NCHW while features are channels-last, so permute around it.
        reduced = self.reduction(input_feature.permute(0, 3, 1, 2))
        return self.norm(reduced.permute(0, 2, 3, 1))
+
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: randomly zero out entire samples of the batch (and rescale the
    survivors by 1 / keep_prob) when applied inside the main path of a residual block.

    A no-op when `drop_prob` is 0 or when not training. This mirrors the timm
    implementation; see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for the "drop path" vs. "drop connect" naming discussion.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dimension.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    binary_mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device).add_(keep_prob).floor_()
    return input.div(keep_prob) * binary_mask
+
+
class NatDropPath(nn.Module):
    """Module wrapper around `drop_path` (stochastic depth per sample in residual paths)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form; active only in training mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
+
+
class NeighborhoodAttention(nn.Module):
    """Multi-head 2D neighborhood attention with learnable relative positional biases."""

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = dim // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size

        # Learnable relative positional biases over the neighborhood window;
        # same concept as Swin's relative position bias table.
        bias_span = 2 * self.kernel_size - 1
        self.rpb = nn.Parameter(torch.zeros(num_heads, bias_span, bias_span))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, height, width, dim) -> (batch, heads, height, width, head_size)
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # Scale the queries up front: scalars commute through the matmul, so this is
        # equivalent to scaling the scores, but it touches the smaller tensor.
        queries = self.transpose_for_scores(self.query(hidden_states)) / math.sqrt(self.attention_head_size)
        keys = self.transpose_for_scores(self.key(hidden_states))
        values = self.transpose_for_scores(self.value(hidden_states))

        # Neighborhood attention scores with relative positional biases added (dilation 1).
        attention_scores = natten2dqkrpb(queries, keys, self.rpb, self.kernel_size, 1)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # Dropout on the attention map drops whole tokens to attend to,
        # as in the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context = natten2dav(attention_probs, values, self.kernel_size, 1)
        context = context.permute(0, 2, 3, 1, 4).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))

        if output_attentions:
            return (context, attention_probs)
        return (context,)
+
+
class NeighborhoodAttentionOutput(nn.Module):
    """
    Output projection after neighborhood attention.

    `input_tensor` is accepted for interface parity with other attention
    implementations but is unused here (no residual is added at this stage).
    """

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
+
+
class NeighborhoodAttentionModule(nn.Module):
    """Neighborhood attention plus its output projection, with head-pruning support."""

    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size)
        self.output = NeighborhoodAttentionOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove `heads` from the attention layer and shrink its projections accordingly."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink q/k/v (rows) and the output projection (columns) to drop pruned dims.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: fewer heads, smaller total head size, remember what was pruned.
        self.self.num_attention_heads -= len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads |= set(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attention_outputs = self.self(hidden_states, output_attentions)
        projected = self.output(attention_outputs[0], hidden_states)
        # Append the attention probabilities when they were requested.
        return (projected,) + attention_outputs[1:]
+
+
class NatIntermediate(nn.Module):
    """First half of the MLP block: expand by `mlp_ratio`, then apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        act = config.hidden_act
        # The activation can be given either as a registry key or as a callable.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class NatOutput(nn.Module):
    """Second half of the MLP block: project back down to `dim` and apply dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
+
+
class NatLayer(nn.Module):
    """One NAT transformer block: pre-norm neighborhood attention and pre-norm MLP,
    each wrapped in a residual connection with optional LayerScale and drop path."""

    def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(config, dim, num_heads, kernel_size=self.kernel_size)
        self.drop_path = NatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = NatIntermediate(config, dim)
        self.output = NatOutput(config, dim)
        # Optional LayerScale: one learnable per-channel scale for each of the two
        # residual branches (row 0: attention, row 1: MLP); disabled when init <= 0.
        self.layer_scale_parameters = (
            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )

    def maybe_pad(self, hidden_states, height, width):
        """Pad bottom/right so both spatial dims are at least `kernel_size`.

        `hidden_states` is channels-last, so the `nn.functional.pad` spec is ordered
        (channel_l, channel_r, width_l, width_r, height_t, height_b). Returns the
        (possibly) padded tensor and the pad spec used (all zeros when no padding).
        """
        window_size = self.kernel_size
        pad_values = (0, 0, 0, 0, 0, 0)
        if height < window_size or width < window_size:
            pad_l = pad_t = 0
            pad_r = max(0, window_size - width)
            pad_b = max(0, window_size - height)
            pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
            hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, height, width, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        # pad hidden_states if they are smaller than kernel size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape

        attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)

        attention_output = attention_outputs[0]

        # Crop any padding back off so the residual shapes match again
        # (indices 3 and 5 are the right/bottom pad amounts).
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_output = attention_output[:, :height, :width, :].contiguous()

        if self.layer_scale_parameters is not None:
            attention_output = self.layer_scale_parameters[0] * attention_output

        # First residual connection (attention branch).
        hidden_states = shortcut + self.drop_path(attention_output)

        # Second branch: pre-norm MLP.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.output(self.intermediate(layer_output))

        if self.layer_scale_parameters is not None:
            layer_output = self.layer_scale_parameters[1] * layer_output

        # Second residual connection (MLP branch).
        layer_output = hidden_states + self.drop_path(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
+
+
class NatStage(nn.Module):
    """A stack of `NatLayer`s at one resolution, followed by optional downsampling."""

    def __init__(self, config, dim, depth, num_heads, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.layers = nn.ModuleList(
            [
                NatLayer(config=config, dim=dim, num_heads=num_heads, drop_path_rate=drop_path_rate[idx])
                for idx in range(depth)
            ]
        )

        # Patch merging layer between stages; the last stage has none.
        self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm) if downsample is not None else None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        for block in self.layers:
            layer_outputs = block(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        # Keep the pre-downsampling activations: the backbone reads them.
        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)
        if output_attentions:
            # Only the last layer's attention weights survive the loop above.
            stage_outputs += layer_outputs[1:]
        return stage_outputs
+
+
class NatEncoder(nn.Module):
    """Stack of `NatStage`s; drop-path rate increases linearly over all layers."""

    def __init__(self, config):
        super().__init__()
        self.num_levels = len(config.depths)
        self.config = config
        # One drop-path rate per layer, linearly ramping from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.levels = nn.ModuleList(
            [
                NatStage(
                    config=config,
                    # Channel count doubles at every stage.
                    dim=int(config.embed_dim * 2**i_layer),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=NatDownsampler if (i_layer < self.num_levels - 1) else None,
                )
                for i_layer in range(self.num_levels)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, NatEncoderOutput]:
        """Run all stages; optionally collect per-stage hidden states and attentions.

        With `output_hidden_states_before_downsampling=True` the recorded hidden
        states are taken before each stage's downsampler (used by `NatBackbone`).
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # rearrange b h w c -> b c h w
            reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.levels):
            layer_outputs = layer_module(hidden_states, output_attentions)

            # Stages return (downsampled, pre-downsampling) hidden states.
            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]

            if output_hidden_states and output_hidden_states_before_downsampling:
                # rearrange b h w c -> b c h w
                reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                # rearrange b h w c -> b c h w
                reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]

        if not return_dict:
            # Note: reshaped hidden states are only available via the ModelOutput.
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return NatEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
+
+
class NatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NatConfig
    base_model_prefix = "nat"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal
            # for initialization; cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
+
+
+NAT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`NatConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+NAT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+ for details.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
@add_start_docstrings(
    "The bare Nat Model transformer outputting raw hidden-states without any specific head on top.",
    NAT_START_DOCSTRING,
)
class NatModel(NatPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        # The NAT attention kernels live in the optional `natten` package.
        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        # Channel count after the last stage (doubles at each downsampling step).
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NatModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatModelOutput]:
        # Fall back to the config defaults for any unset flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # (batch, h, w, c) -> (batch, h*w, c) -> (batch, c, h*w), then
            # average-pool over all spatial positions and flatten to (batch, c).
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return NatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
+
+
@add_start_docstrings(
    """
    Nat Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    NAT_START_DOCSTRING,
)
class NatForImageClassification(NatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # The NAT attention kernels live in the optional `natten` package.
        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.nat = NatModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.nat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=NatImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled output in both the tuple and the ModelOutput form.
        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype,
            # then cache it on the config for subsequent calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Skip the pooled output (index 1) when re-packing the tuple.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return NatImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
+
+
@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    NAT_START_DOCSTRING,
)
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        # The NAT attention kernels live in the optional `natten` package.
        requires_backends(self, ["natten"])

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)
        # Channel counts for ["stem", "stage1", ..., "stageN"].
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Always request the pre-downsampling hidden states: those are the
        # per-stage feature maps a backbone consumer expects.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                # TODO can we simplify this?
                # Round-trip NCHW -> (N, HW, C) so LayerNorm normalizes over
                # channels, then restore the NCHW layout.
                batch_size, num_channels, height, width = hidden_state.shape
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
+
+
+__all__ = ["NatForImageClassification", "NatModel", "NatPreTrainedModel", "NatBackbone"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0690129ae9edf84829278790b5e065bbd6608ee
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers / IDEs: expose every public name eagerly so
    # autocomplete and type analysis see the real symbols.
    from .configuration_nezha import *
    from .modeling_nezha import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy so heavy
    # submodules (e.g. the torch-based modeling code) are only imported on
    # first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py
new file mode 100644
index 0000000000000000000000000000000000000000..00d193cd1ae68655e1631b7caea2220987a29f0b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py
@@ -0,0 +1,105 @@
+from .... import PretrainedConfig
+
+
class NezhaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`NezhaModel`]. It is used to instantiate an Nezha
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Nezha
    [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, optional, defaults to 21128):
            Vocabulary size of the NEZHA model. Defines the different tokens that can be represented by the
            *inputs_ids* passed to the forward method of [`NezhaModel`].
        hidden_size (`int`, optional, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, optional, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, optional, defaults to 3072):
            The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, optional, defaults to "gelu"):
            The non-linear activation function (function or string) in the encoder and pooler.
        hidden_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, optional, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, optional, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            (e.g., 512 or 1024 or 2048).
        max_relative_position (`int`, optional, defaults to 64):
            The maximum relative distance used by the functional relative position encoding; relative distances are
            clipped to `[-max_relative_position, max_relative_position]`.
        type_vocab_size (`int`, optional, defaults to 2):
            The vocabulary size of the *token_type_ids* passed into [`NezhaModel`].
        initializer_range (`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout (`float`, optional, defaults to 0.1):
            The dropout ratio for attached classifiers.
        pad_token_id (`int`, optional, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, optional, defaults to 2):
            The id of the beginning-of-sequence token.
        eos_token_id (`int`, optional, defaults to 3):
            The id of the end-of-sequence token.
        use_cache (`bool`, optional, defaults to `True`):
            Whether or not the model should return the last key/values attentions (only relevant when used as a
            decoder).
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.

    Example:

    ```python
    >>> from transformers import NezhaConfig, NezhaModel

    >>> # Initializing an Nezha configuration
    >>> configuration = NezhaConfig()

    >>> # Initializing a model (with random weights) from the Nezha-base style configuration model
    >>> model = NezhaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nezha"

    def __init__(
        self,
        vocab_size=21128,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        max_relative_position=64,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout=0.1,
        pad_token_id=0,
        bos_token_id=2,
        eos_token_id=3,
        use_cache=True,
        **kwargs,
    ):
        # Token ids are forwarded to PretrainedConfig; everything else is stored here.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.max_relative_position = max_relative_position
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
+
+
+__all__ = ["NezhaConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py
new file mode 100644
index 0000000000000000000000000000000000000000..7be52bee5847cb9c82b2efb32e8729f198f9ae75
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py
@@ -0,0 +1,1697 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Nezha model."""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_nezha import NezhaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "sijunhe/nezha-cn-base"
+_CONFIG_FOR_DOC = "NezhaConfig"
+
+
def load_tf_weights_in_nezha(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Walks every variable in the TF checkpoint, maps its slash-separated scope
    name onto the PyTorch module tree and copies the array into the matching
    parameter. `config` is unused here but kept for API parity with the other
    `load_tf_weights_in_*` loaders.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            # TF scopes like "layer_0" index into a ModuleList: split name and index.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # Translate TF variable-name conventions to PyTorch attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    # NOTE(review): this `continue` only skips the current scope
                    # component (the inner loop), not the whole variable — later
                    # components are still applied to the stale `pointer`. Matches
                    # the other load_tf_weights_* loaders; confirm before changing.
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF stores dense kernels transposed relative to torch.nn.Linear.
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except AssertionError as e:
            # NOTE(review): dead handler — the try block raises ValueError, not
            # AssertionError; kept byte-compatible with the original loader.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
+
+
class NezhaRelativePositionsEncoding(nn.Module):
    """Implement the Functional Relative Position Encoding.

    Precomputes a `(length, length, depth)` tensor where entry `[i, j]` is a
    sinusoidal embedding of the clipped relative distance `j - i`.
    """

    def __init__(self, length, depth, max_relative_position=127):
        super().__init__()
        num_buckets = 2 * max_relative_position + 1

        # Pairwise distances d[i, j] = j - i, clipped to [-max, +max] and
        # shifted to non-negative bucket ids in [0, 2 * max].
        positions = torch.arange(length)
        relative = positions[None, :] - positions[:, None]
        bucket_ids = torch.clamp(relative, -max_relative_position, max_relative_position) + max_relative_position

        # Classic sinusoidal table over the bucket ids: even feature columns
        # hold sines, odd columns cosines, with geometrically spaced frequencies.
        table = torch.zeros(num_buckets, depth)
        angles = torch.arange(0, num_buckets, dtype=torch.int64).float().unsqueeze(1)
        freqs = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth))
        table[:, 0::2] = torch.sin(angles * freqs)
        table[:, 1::2] = torch.cos(angles * freqs)

        # Direct row lookup — same result as the original one-hot matmul,
        # without materializing the (length^2, num_buckets) one-hot matrix.
        self.register_buffer("positions_encoding", table[bucket_ids], persistent=False)

    def forward(self, length):
        # Return the (length, length, depth) top-left slice of the table.
        return self.positions_encoding[:length, :length, :]
+
+
class NezhaEmbeddings(nn.Module):
    """Construct the embeddings from word and token_type embeddings.

    Note there are no absolute position embeddings here — position information
    is injected through relative encodings inside self-attention.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # Name kept un-snake-cased so TensorFlow checkpoint variables map onto it.
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # All-zero default token_type_ids buffer; having it registered keeps
        # tracing happy when callers omit token_type_ids.
        self.register_buffer(
            "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        shape = inputs_embeds.size()[:-1] if input_ids is None else input_ids.size()
        seq_length = shape[1]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if token_type_ids is None:
            # Fall back to the registered all-zero buffer (defensive hasattr
            # check retained from the original implementation).
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(shape, dtype=torch.long, device=inputs_embeds.device)

        summed = inputs_embeds + self.token_type_embeddings(token_type_ids)
        normalized = self.LayerNorm(summed)
        return self.dropout(normalized)
+
+
class NezhaSelfAttention(nn.Module):
    """Multi-head self-attention with NEZHA's functional relative position encoding.

    Unlike absolute-position models, a sinusoidal relative-position table
    contributes both to the attention scores (query x relative-key) and to the
    context vectors (probs x relative-value).
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Precomputed (seq, seq, head_dim) sinusoidal relative-position table,
        # sized for the maximum supported sequence length.
        self.relative_positions_encoding = NezhaRelativePositionsEncoding(
            length=config.max_position_embeddings,
            depth=self.attention_head_size,
            max_relative_position=config.max_relative_position,
        )
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_dim)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()
        # (to_seq, to_seq, head_dim) slice of the sinusoidal relative-position table.
        relations_keys = self.relative_positions_encoding(to_seq_length)
        query_layer_t = query_layer.permute(2, 0, 1, 3)

        # Query x relative-key term, added to the content-based scores.
        # NOTE(review): the reshapes below assume from_seq_length == to_seq_length;
        # with a non-empty past key/value cache these differ — confirm decoder usage.
        query_layer_r = query_layer_t.contiguous().view(
            from_seq_length, batch_size * num_attention_heads, self.attention_head_size
        )
        key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
        key_position_scores_r = key_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, from_seq_length
        )
        key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
        attention_scores = attention_scores + key_position_scores_r_t

        # Standard scaled dot-product attention scaling.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in NezhaModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # Probs x relative-value term, added to the content-based context.
        relations_values = self.relative_positions_encoding(to_seq_length)
        attention_probs_t = attention_probs.permute(2, 0, 1, 3)
        attentions_probs_r = attention_probs_t.contiguous().view(
            from_seq_length, batch_size * num_attention_heads, to_seq_length
        )
        value_position_scores = torch.matmul(attentions_probs_r, relations_values)
        value_position_scores_r = value_position_scores.view(
            from_seq_length, batch_size, num_attention_heads, self.attention_head_size
        )
        value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
        context_layer = context_layer + value_position_scores_r_t

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
class NezhaSelfOutput(nn.Module):
    """Output projection of the attention sublayer: Linear, dropout, then
    residual add + LayerNorm against the sublayer input."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
+
+
class NezhaAttention(nn.Module):
    """Full attention sublayer: NezhaSelfAttention followed by its output
    projection, with support for pruning attention heads."""

    def __init__(self, config):
        super().__init__()
        self.self = NezhaSelfAttention(config)
        self.output = NezhaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given heads from the q/k/v projections and the output dense."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink each projection to the surviving head dimensions.
        for proj_name in ("query", "key", "value"):
            setattr(self.self, proj_name, prune_linear_layer(getattr(self.self, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping so subsequent pruning calls see the updated geometry.
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(attn_results[0], hidden_states)
        # Prepend the projected context; keep attentions / cache entries as-is.
        return (attention_output,) + attn_results[1:]
+
+
class NezhaIntermediate(nn.Module):
    """Position-wise feed-forward expansion: Linear to intermediate_size, then activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Accept either a string key into ACT2FN or a ready-made callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class NezhaOutput(nn.Module):
    """Feed-forward output projection: Linear back to hidden_size, dropout,
    then residual add + LayerNorm against the sublayer input."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
+
+
class NezhaLayer(nn.Module):
    """One Transformer block: self-attention, optional cross-attention (decoder
    only), then the chunked feed-forward sublayer."""

    def __init__(self, config):
        super().__init__()
        # Feed-forward may be applied in chunks along the sequence dimension
        # (dim 1) to reduce peak memory; 0 disables chunking.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = NezhaAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = NezhaAttention(config)
        self.intermediate = NezhaIntermediate(config)
        self.output = NezhaOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Apply the feed-forward sublayer, optionally chunked over the sequence dim.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Single feed-forward application on one chunk of the sequence.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
+
+
class NezhaEncoder(nn.Module):
    """Stack of `num_hidden_layers` NezhaLayer blocks with optional gradient
    checkpointing, KV caching and hidden-state/attention collection."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([NezhaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        # Accumulators stay None when the corresponding output is not requested.
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            # `all_hidden_states` collects the INPUT of each layer; the final
            # output is appended once after the loop.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward instead of storing them.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Present key/value cache is the last element of a decoder layer's outputs.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form: drop every accumulator that was not requested.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
class NezhaPooler(nn.Module):
    """Pool the sequence by passing the first ([CLS]) token's hidden state
    through a Linear layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only the first token's representation is pooled.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
+
+
class NezhaPredictionHeadTransform(nn.Module):
    """Pre-decoder transform for the MLM head: Linear, activation, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Activation may be given as an ACT2FN key or directly as a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
+
+
class NezhaLMPredictionHead(nn.Module):
    """Masked-LM head: transform the hidden states, then project to the
    vocabulary with a (weight-tied) decoder plus an output-only bias."""

    def __init__(self, config):
        super().__init__()
        self.transform = NezhaPredictionHeadTransform(config)

        # The decoder weight is tied to the input embeddings elsewhere; the
        # bias is a standalone per-token parameter.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Keep decoder.bias aliased to self.bias so `resize_token_embeddings`
        # resizes both together.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-establish the alias after any re-tying / resizing.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
+
+
class NezhaOnlyMLMHead(nn.Module):
    """Thin wrapper exposing only the masked-LM prediction head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        # (batch, seq, vocab) logits for every position.
        return self.predictions(sequence_output)
+
+
class NezhaOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: binary classifier over the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        # (batch, 2) is-next / not-next logits.
        return self.seq_relationship(pooled_output)
+
+
class NezhaPreTrainingHeads(nn.Module):
    """Joint pre-training heads producing masked-LM and next-sentence logits."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NezhaLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(mlm_logits, nsp_logits)` for the two pre-training objectives."""
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
+
+
class NezhaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NezhaConfig
    load_tf_weights = load_tf_weights_in_nezha
    base_model_prefix = "nezha"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal
            # (cf. https://github.com/pytorch/pytorch/pull/5617).
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
@dataclass
class NezhaForPreTrainingOutput(ModelOutput):
    """
    Output type of [`NezhaForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None  # summed MLM + NSP loss (see class docstring)
    prediction_logits: Optional[torch.FloatTensor] = None  # per-token vocabulary logits
    seq_relationship_logits: Optional[torch.FloatTensor] = None  # binary next-sentence logits
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None  # only when output_hidden_states=True
    attentions: Optional[Tuple[torch.FloatTensor]] = None  # only when output_attentions=True
+
+
# Shared class-level docstring chunk injected by the `add_start_docstrings` decorator.
NEZHA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`NezhaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Per-forward docstring chunk; the `{0}` placeholder is filled with the concrete input
# shape string by `add_start_docstrings_to_model_forward`.
NEZHA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Nezha Model transformer outputting raw hidden-states without any specific head on top.",
    NEZHA_START_DOCSTRING,
)
class NezhaModel(NezhaPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = NezhaEmbeddings(config)
        self.encoder = NezhaEncoder(config)

        # Token-level heads (e.g. MLM, token classification) pass add_pooling_layer=False
        # so no [CLS] pooler is built.
        self.pooler = NezhaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The word-embedding table doubles as the tied output embedding for LM heads.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Key/value caching only applies to decoder-style (causal) usage.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # Default mask covers the cached prefix plus the new tokens.
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # Reuse the registered all-zeros token_type_ids buffer when the caller passes none.
        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # NOTE(review): no position ids are passed to the embeddings here — positional
        # information appears to be handled elsewhere (presumably relative positions
        # inside the attention layers); confirm against NezhaEmbeddings/NezhaSelfAttention.
        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # pooled_output is None when the model was built with add_pooling_layer=False.
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
+
+
@add_start_docstrings(
    """
    Nezha Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForPreTraining(NezhaPreTrainedModel):
    # The MLM decoder weight is tied to the input word embeddings.
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # The LM decoder serves as the output embedding matrix (tied to the input embeddings).
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        # Keep the standalone bias in sync with the (possibly resized) decoder.
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], NezhaForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
            pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForPreTraining.from_pretrained("sijunhe/nezha-cn-base")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The combined loss is only computed when BOTH objectives have labels.
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return NezhaForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING)
class NezhaForMaskedLM(NezhaPreTrainedModel):
    # The MLM decoder weight is tied to the input word embeddings.
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `NezhaForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # No pooler: MLM only needs per-token hidden states.
        self.nezha = NezhaModel(config, add_pooling_layer=False)
        self.cls = NezhaOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # The LM decoder serves as the output embedding matrix (tied to the input embeddings).
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        # Keep the standalone bias in sync with the (possibly resized) decoder.
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append a dummy PAD token so generation has a position to predict at.

        NOTE(review): this assumes `attention_mask` is not None — a None mask would
        fail at the `torch.cat` below; confirm callers always provide one.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # The appended position is masked out (0) so it does not attend to anything.
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
@add_start_docstrings(
    """Nezha Model with a `next sentence prediction (classification)` head on top.""",
    NEZHA_START_DOCSTRING,
)
class NezhaForNextSentencePrediction(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.nezha = NezhaModel(config)
        self.cls = NezhaOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, NezhaForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base")
        >>> model = NezhaForNextSentencePrediction.from_pretrained("sijunhe/nezha-cn-base")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """

        # Backwards compatibility: accept the deprecated `next_sentence_label` kwarg.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model's output is the pooled [CLS] representation.
        pooled_output = outputs[1]

        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Nezha Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForSequenceClassification(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.nezha = NezhaModel(config)
        # Fall back to the generic hidden dropout when no classifier-specific rate is set.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model's output is the pooled [CLS] representation.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels / label dtype and cache it
            # on the config so subsequent calls reuse the same loss.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Nezha Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    NEZHA_START_DOCSTRING,
)
class NezhaForMultipleChoice(NezhaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.nezha = NezhaModel(config)
        # Fall back to the generic hidden dropout when no classifier-specific rate is set.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Scores a single (choice, context) pair; scores are regrouped per example in forward().
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        # Flatten (batch, num_choices, seq) -> (batch * num_choices, seq) so the encoder
        # processes each choice as an independent sequence.
        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.nezha(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the base model's output is the pooled [CLS] representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Fix: removed leftover debug `print(...)` statements that spammed stdout on every call.
        # Regroup the flat per-sequence scores back into (batch_size, num_choices).
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
+@add_start_docstrings(
+    """
+    Nezha Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    Named-Entity-Recognition (NER) tasks.
+    """,
+    NEZHA_START_DOCSTRING,
+)
+class NezhaForTokenClassification(NezhaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # No pooling layer: token classification consumes the per-token hidden states directly.
+        self.nezha = NezhaModel(config, add_pooling_layer=False)
+        # Prefer the classifier-specific dropout if configured, otherwise reuse the generic hidden dropout.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.nezha(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # First element of the encoder output is the sequence of per-token hidden states.
+        sequence_output = outputs[0]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Flatten batch and sequence dims so every token is an independent classification example.
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            # Tuple output: logits followed by any extra encoder outputs (hidden states / attentions).
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Nezha Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    NEZHA_START_DOCSTRING,
+)
+class NezhaForQuestionAnswering(NezhaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # No pooling layer: span prediction works on the per-token hidden states.
+        self.nezha = NezhaModel(config, add_pooling_layer=False)
+        # Projects each token's hidden state to two logits: span start and span end.
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=QuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.nezha(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # Split the two logits per token into separate start/end tensors of shape (batch_size, seq_len).
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split add a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            # Clamped-out positions land exactly on ignore_index, so they contribute no loss.
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            # Average the two head losses so the span loss stays on the scale of a single classifier.
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Explicit public API of this module, consumed by the package-level lazy import machinery.
+__all__ = [
+    "NezhaForNextSentencePrediction",
+    "NezhaForMaskedLM",
+    "NezhaForPreTraining",
+    "NezhaForMultipleChoice",
+    "NezhaForQuestionAnswering",
+    "NezhaForSequenceClassification",
+    "NezhaForTokenClassification",
+    "NezhaModel",
+    "NezhaPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3964d194bed041987c6236b5c60bfcf3b7caf4
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2023 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers resolve the real symbols eagerly.
+    from .configuration_open_llama import *
+    from .modeling_open_llama import *
+else:
+    import sys
+
+    # At runtime, replace this module object with a lazy proxy that imports
+    # the submodules only on first attribute access.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4bc9cc72a7661fe571cce6649b40f2307d9b3ce
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Open-Llama model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class OpenLlamaConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`OpenLlamaModel`]. It is used to instantiate an
+    Open-Llama model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the
+    [s-JoL/Open-Llama-V1](https://huggingface.co/s-JoL/Open-Llama-V1).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 100000):
+            Vocabulary size of the Open-Llama model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`OpenLlamaModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        use_memory_efficient_attention (`bool`, *optional*, defaults to `True`):
+            Whether to use xformers' memory-efficient attention kernel (training-time only) when xformers is
+            installed.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            Dropout probability applied to the MLP output in each decoder layer.
+        attention_dropout_prob (`float`, *optional*, defaults to 0.1):
+            Dropout probability applied inside the attention computation.
+
+    Example:
+
+    ```python
+    >>> from transformers import OpenLlamaModel, OpenLlamaConfig
+
+    >>> # Initializing a Open-Llama open_llama-7b style configuration
+    >>> configuration = OpenLlamaConfig()
+
+    >>> # Initializing a model from the open_llama-7b style configuration
+    >>> model = OpenLlamaModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "open-llama"
+
+    def __init__(
+        self,
+        vocab_size=100000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        use_memory_efficient_attention=True,
+        hidden_dropout_prob=0.1,
+        attention_dropout_prob=0.1,
+        use_stable_embedding=True,
+        shared_input_output_embedding=True,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        # Also accept the historically misspelled kwarg "use_memorry_efficient_attention"
+        # so configs saved with the old name keep working.
+        self.use_memory_efficient_attention = kwargs.pop(
+            "use_memorry_efficient_attention", use_memory_efficient_attention
+        )
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_dropout_prob = attention_dropout_prob
+        self.use_stable_embedding = use_stable_embedding
+        self.shared_input_output_embedding = shared_input_output_embedding
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        # Fail fast on malformed rope_scaling before the model is built.
+        self._rope_scaling_validation()
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+
+        Raises:
+            ValueError: if `rope_scaling` is not a two-field dict with a known `type` and a float `factor` > 1.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        # NOTE(review): an integer factor (e.g. 2) is rejected by the strict float isinstance
+        # check below — confirm this is intended.
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
+
+__all__ = ["OpenLlamaConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d79ea546a950bd3e6deb81b93ff6ee7b5d60a4
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py
@@ -0,0 +1,975 @@
+# coding=utf-8
+# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Open-Llama model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ....modeling_utils import PreTrainedModel
+from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_open_llama import OpenLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+try:
+    from xformers import ops as xops
+except ImportError:
+    # xformers is optional; attention falls back to the plain PyTorch path when it is absent.
+    xops = None
+
+
+_CONFIG_FOR_DOC = "OpenLlamaConfig"
+
+
+class OpenLlamaRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ OpenLlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class OpenLlamaRotaryEmbedding(nn.Module):
+    """Precomputes rotary position embedding cos/sin tables and serves slices of them on demand."""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # One inverse frequency per pair of dimensions: base ** (-2i / dim).
+        inv_freq = 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+        )
+        # persistent=False keeps these derived buffers out of the state dict.
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        # Cache cos/sin for positions [0, seq_len) so forward can slice instead of recomputing.
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # NOTE(review): seq_len is effectively required; the None default would fail the comparison below.
+        if seq_len > self.max_seq_len_cached:
+            # Grow the cache lazily when a longer sequence than any seen so far arrives.
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
+    """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        # Must be set before super().__init__, which builds the cache via _set_cos_sin_cache.
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+        # Compress position indices by the scaling factor to stretch the usable context window.
+        t = t / self.scaling_factor
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
+    """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        # Must be set before super().__init__, which builds the cache via _set_cos_sin_cache.
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        # Only when the requested length exceeds the trained context does the base get
+        # re-derived (NTK-aware interpolation); shorter sequences reuse the original inv_freq.
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # gather the per-position rows, add a broadcast dim for heads
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)  # complex rotation expressed on the half-split layout
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class OpenLlamaMLP(nn.Module):
+    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x)), followed by dropout."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        dropout_prob: float,
+    ):
+        super().__init__()
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        # hidden_act names an entry in the ACT2FN registry (e.g. "silu").
+        self.act_fn = ACT2FN[hidden_act]
+        self.dropout = nn.Dropout(dropout_prob)
+
+    def forward(self, x):
+        # The gate branch goes through the activation and scales the up branch elementwise.
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return self.dropout(out)
+
+
+class OpenLlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: OpenLlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.dropout_prob = config.attention_dropout_prob
+ self.rope_theta = config.rope_theta
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = OpenLlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ if self.config.use_memory_efficient_attention and xops is not None and self.training:
+ attn_weights = None
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = xops.memory_efficient_attention(
+ query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
+ )
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OpenLlamaDecoderLayer(nn.Module):
+    """One pre-norm decoder block: RMSNorm -> self-attention -> residual, RMSNorm -> MLP -> residual."""
+
+    def __init__(self, config: OpenLlamaConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = OpenLlamaAttention(config=config)
+        self.mlp = OpenLlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            dropout_prob=config.hidden_dropout_prob,
+        )
+        self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        # Pre-norm: normalize before attention, add the raw input back afterwards.
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # Output tuple mirrors the flags: hidden states, then optionally weights and cache.
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
# Shared class-level docstring header injected into the model classes below via
# `@add_start_docstrings`. Runtime string: its content must not be edited casually.
OPEN_LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`OpenLlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaPreTrainedModel(PreTrainedModel):
    """Base class wiring Open-Llama models into the `PreTrainedModel` machinery."""

    # Hooks consumed by the shared `PreTrainedModel` infrastructure.
    config_class = OpenLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OpenLlamaDecoderLayer"]

    def _init_weights(self, module):
        """Initialize one submodule; dispatched per layer type."""
        if isinstance(module, nn.Linear):
            self._init_linear(module)
        elif isinstance(module, nn.Embedding):
            self._init_embedding(module)

    def _init_linear(self, module):
        # Normal-distributed weights; biases start at zero.
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()

    def _init_embedding(self, module):
        # Stable-embedding configs use Xavier init; otherwise a normal draw.
        if self.config.use_stable_embedding:
            torch.nn.init.xavier_normal_(module.weight.data)
        else:
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.padding_idx is not None:
            # Keep the padding row at zero so padded positions contribute nothing.
            module.weight.data[module.padding_idx].zero_()
+
+
# Shared forward-pass docstring injected via `@add_start_docstrings_to_model_forward`.
# NOTE(review): the two "head is **not masked**" bullets below sit under `attention_mask`
# but describe a head mask — looks like a copy artifact inherited from upstream; confirm
# against the canonical transformers docstring before changing the string itself.
OPEN_LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.",
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaModel(OpenLlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OpenLlamaDecoderLayer`]

    Args:
        config: OpenLlamaConfig
    """

    def __init__(self, config: OpenLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        if config.use_stable_embedding:
            # Extra LayerNorm applied to the token embeddings in `forward` (see below).
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.embed_layer_norm = None
        self.layers = nn.ModuleList([OpenLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Final RMSNorm applied after the last decoder layer.
        self.norm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack and return the last hidden state (plus optional cache/attentions)."""
        # Resolve per-call flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if self.gradient_checkpointing and self.training:
            # Checkpointing recomputes activations, which is incompatible with returning a cache.
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if past_key_values is not None:
            # Cached prefix length is read off the key tensor's sequence dimension.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            # Positions continue from the cached prefix: [past_len, past_len + seq_len).
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            if self.embed_layer_norm:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)
        # embed positions
        if self.config.use_memory_efficient_attention and self.training:
            # The memory-efficient attention path during training runs without an explicit mask.
            attention_mask = None
        elif attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )

        # Expand the 2D padding mask into the 4D causal mask the attention layers expect.
        input_shape = (batch_size, seq_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Positional args mirror the layer's signature; `None` is passed for
                # `past_key_value` and `use_cache` since caching is disabled here.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                    output_attentions,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache tuple sits after the attention weights when those are also returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            # Tuple output: drop the entries that were not requested.
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
+
+
class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
    """Open-Llama decoder with a language-modeling head for causal (next-token) generation."""

    def __init__(self, config):
        super().__init__(config)
        self.model = OpenLlamaModel(config)
        if config.shared_input_output_embedding:
            # Weight tying: the input embedding matrix doubles as the output projection
            # (see the einsum in `forward`), so no separate lm_head is created.
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # May be None when input/output embeddings are shared (see __init__).
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OpenLlamaForCausalLM

        >>> model = OpenLlamaForCausalLM.from_pretrained("openlm-research/open_llama_7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.shared_input_output_embedding:
            # Tied embeddings: project hidden states against the (transposed) input
            # embedding matrix instead of a dedicated lm_head.
            logits = torch.einsum(
                "blh,vh->blv", hidden_states.to(self.model.embed_tokens.weight.device), self.model.embed_tokens.weight
            )
        else:
            logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim inputs to the not-yet-cached suffix and assemble kwargs for a generation step."""
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            # Cumulative sum over the mask gives positions that skip padded slots.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Re-index every cached key/value tensor along the batch dim so the cache
        # follows the beams selected during beam search.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`OpenLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPEN_LLAMA_START_DOCSTRING,
)
class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OpenLlamaModel(config)
        # Classification head applied per position; pooling picks one position later.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            # No pad token: fall back to pooling the last position of each row.
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                # argmax finds the first pad position; -1 steps back to the last real token,
                # and the modulo maps the pad-free case (argmax == 0) to the final position.
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # Pool a single logit vector per sequence at the chosen position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once from num_labels and the label dtype.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+
+
# Public API of this module.
__all__ = ["OpenLlamaPreTrainedModel", "OpenLlamaModel", "OpenLlamaForCausalLM", "OpenLlamaForSequenceClassification"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..864b321bc2ee3a521e3d6da5403cab7363dea56b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # Static type checkers import the real submodules directly.
    from .configuration_qdqbert import *
    from .modeling_qdqbert import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy that only
    # imports the submodules on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..91ac82bc5a0292355ac659548c7c3d4336f2536c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py
@@ -0,0 +1,123 @@
+# coding=utf-8
+# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""QDQBERT model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class QDQBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`QDQBertModel`]. It is used to instantiate an
    QDQBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the QDQBERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`QDQBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`QDQBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.

    Examples:

    ```python
    >>> from transformers import QDQBertModel, QDQBertConfig

    >>> # Initializing a QDQBERT google-bert/bert-base-uncased style configuration
    >>> configuration = QDQBertConfig()

    >>> # Initializing a model from the google-bert/bert-base-uncased style configuration
    >>> model = QDQBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qdqbert"

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act="gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        type_vocab_size: int = 2,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        use_cache: bool = True,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        **kwargs,
    ):
        # Special-token ids and any remaining kwargs are handled by the base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
+
+
# Public API of this module.
__all__ = ["QDQBertConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b68a4e426d445f5c6e3c472e3dc5ea54e661d57
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py
@@ -0,0 +1,1749 @@
+# coding=utf-8
+# Copyright 2021 NVIDIA Corporation and The HuggingFace Team.
+# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch QDQBERT model."""
+
+import math
+import os
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_pytorch_quantization_available,
+ logging,
+ replace_return_docstrings,
+ requires_backends,
+)
+from .configuration_qdqbert import QDQBertConfig
+
+
# Module-level logger following the library convention.
logger = logging.get_logger(__name__)

# soft dependency: the quantized layers below require NVIDIA's `pytorch_quantization`;
# import failures are logged (not raised) so the rest of the library stays importable.
if is_pytorch_quantization_available():
    try:
        from pytorch_quantization import nn as quant_nn
        from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer
    except OSError:
        logger.error(
            "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it"
            " following the instructions here:"
            " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization."
        )

# Checkpoint / config names used by the docstring decorators below.
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "QDQBertConfig"
+
+
def load_tf_weights_in_qdqbert(model, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Args:
        model: QDQBERT PyTorch model whose parameters are populated in place.
        tf_checkpoint_path: path to the TensorFlow checkpoint to read.

    Returns:
        The same `model` instance with weights copied from the checkpoint.

    Raises:
        ImportError: if TensorFlow is not installed.
        ValueError: if a checkpoint tensor's shape does not match the target parameter.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # Walk the PyTorch module tree following the TF scope path.
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # Map TF variable names onto PyTorch attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF stores dense kernels transposed relative to torch.nn.Linear.
            array = np.transpose(array)
        # Fix: the original wrapped this check in `try/except AssertionError`, but the
        # body raises ValueError, so the except clause was unreachable dead code.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
+
+
class QDQBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Sum word/position/token-type embeddings, then apply LayerNorm and dropout."""
        if input_ids is not None:
            batch_shape = input_ids.size()
        else:
            batch_shape = inputs_embeds.size()[:-1]
        seq_len = batch_shape[1]

        if position_ids is None:
            # Slice the cached position ids, offset by any already-cached past length.
            position_ids = self.position_ids[:, past_key_values_length : seq_len + past_key_values_length]

        if token_type_ids is None:
            # Fall back to the registered all-zeros buffer; this keeps tracing working
            # when callers omit token_type_ids (see issue #5664).
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_len].expand(batch_shape[0], seq_len)
            else:
                token_type_ids = torch.zeros(batch_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
+
+
class QDQBertSelfAttention(nn.Module):
    """
    BERT-style multi-head self-attention with fake-quantization (QDQ) nodes inserted on the
    inputs of both attention matmuls, so the exported graph can run in INT8 on TensorRT.
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Quantized Q/K/V projections (drop-in replacements for nn.Linear).
        self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)
        self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

        # Fake-quantizers for the four matmul operands: Q x K^T and probs x V.
        self.matmul_q_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_k_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_v_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.matmul_a_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Return (context, [attn_probs], [present_key_value]) depending on flags."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Decoder self-attention: append the new K/V to the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # Both matmul inputs are fake-quantized first.
        attention_scores = torch.matmul(
            self.matmul_q_input_quantizer(query_layer), self.matmul_k_input_quantizer(key_layer.transpose(-1, -2))
        )

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in QDQBertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Second quantized matmul: attention probabilities times values.
        context_layer = torch.matmul(
            self.matmul_a_input_quantizer(attention_probs), self.matmul_v_input_quantizer(value_layer)
        )

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
class QDQBertSelfOutput(nn.Module):
    """Post-attention projection: quantized dense -> dropout -> quantized residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Quantized projection back to hidden size
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Both operands of the residual add are fake-quantized so the add can run in INT8
        self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Quantize both inputs of the residual connection before summing
        quantized_local = self.add_local_input_quantizer(projected)
        quantized_residual = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(quantized_local + quantized_residual)
+
+
# Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert
class QDQBertAttention(nn.Module):
    """Self-attention plus its quantized output projection, with head-pruning support."""

    def __init__(self, config):
        super().__init__()
        self.self = QDQBertSelfAttention(config)
        self.output = QDQBertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads in place; no-op for an empty collection."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink the projections to drop the pruned heads' rows/columns
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: fewer heads, smaller total head size, remember what was pruned
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        attn_results = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Project the context, then re-attach attention weights / cache entries, if any
        projected = self.output(attn_results[0], hidden_states)
        return (projected,) + attn_results[1:]
+
+
class QDQBertIntermediate(nn.Module):
    """Feed-forward expansion: quantized linear layer followed by the activation function."""

    def __init__(self, config):
        super().__init__()
        # Quantized Linear layer expanding hidden_size -> intermediate_size
        self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size)
        # config.hidden_act may be an activation name or a callable
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class QDQBertOutput(nn.Module):
    """Feed-forward contraction: quantized dense -> dropout -> quantized residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Quantized Linear layer contracting intermediate_size -> hidden_size
        self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Both operands of the residual add are fake-quantized so the add can run in INT8
        self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)
        self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        # Quantize both inputs of the residual connection before summing
        quantized_local = self.add_local_input_quantizer(contracted)
        quantized_residual = self.add_residual_input_quantizer(input_tensor)
        return self.LayerNorm(quantized_local + quantized_residual)
+
+
# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert
class QDQBertLayer(nn.Module):
    """One transformer block: self-attention, optional cross-attention, and feed-forward."""

    def __init__(self, config):
        super().__init__()
        # Dimension along which the sequence runs (used by chunked feed-forward helpers)
        self.seq_len_dim = 1
        self.attention = QDQBertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = QDQBertAttention(config)
        self.intermediate = QDQBertIntermediate(config)
        self.output = QDQBertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run the block; returns (layer_output, [attentions...], [present_key_value])."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Intermediate expansion followed by contraction with residual + LayerNorm
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
+
+
# Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert
class QDQBertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` QDQBertLayer blocks with optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run all layers, optionally collecting hidden states, attentions and KV cache."""
        # Accumulators are tuples (or None when the corresponding output is disabled)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the INPUT to each layer (final state is appended after the loop)
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Checkpointing recomputes activations in backward, which is incompatible
                # with returning a cache, so caching is force-disabled here.
                if use_cache:
                    logger.warning_once(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer's present key/value tuple is always its last output
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form: drop entries that were not requested (None accumulators)
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
class QDQBertPooler(nn.Module):
    """Pool the sequence by transforming the first ([CLS]) token's hidden state."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pooling" is simply the first token's hidden state, passed through dense+tanh.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
+
+
class QDQBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # config.hidden_act may be an activation name or a callable
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        transformed = self.dense(hidden_states)
        transformed = self.transform_act_fn(transformed)
        return self.LayerNorm(transformed)
+
+
# Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert
class QDQBertLMPredictionHead(nn.Module):
    """Masked-LM head: per-token transform followed by projection to vocabulary logits."""

    def __init__(self, config):
        super().__init__()
        self.transform = QDQBertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-link the bias after the decoder may have been replaced during weight tying
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
+
+
# Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert
class QDQBertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing only the masked-LM prediction head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
+
+
class QDQBertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: binary classifier over the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
+
+
# Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert
class QDQBertPreTrainingHeads(nn.Module):
    """Bundles the masked-LM head and the NSP classifier used during pre-training."""

    def __init__(self, config):
        super().__init__()
        self.predictions = QDQBertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        return self.predictions(sequence_output), self.seq_relationship(pooled_output)
+
+
# Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert
class QDQBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = QDQBertConfig
    load_tf_weights = load_tf_weights_in_qdqbert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights with the BERT scheme (normal init, zero biases)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
# Shared docstring fragments consumed by the `@add_start_docstrings*` decorators below.
QDQBERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`QDQBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# `{0}` is filled with the expected input shape by `add_start_docstrings_to_model_forward`.
QDQBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
+@add_start_docstrings(
+ "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.",
+ QDQBERT_START_DOCSTRING,
+)
+class QDQBertModel(QDQBertPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+ """
+
+ def __init__(self, config, add_pooling_layer: bool = True):
+ requires_backends(self, "pytorch_quantization")
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = QDQBertEmbeddings(config)
+ self.encoder = QDQBertEncoder(config)
+
+ self.pooler = QDQBertPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
@add_start_docstrings(
    """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING
)
class QDQBertLMHeadModel(QDQBertPreTrainedModel):
    # LM-head decoder weights/bias are tied to the input word embeddings.
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True.`")

        # Pooler not needed: LM scores are computed from per-token hidden states.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Decoder projection of the MLM head; exposed for weight tying.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertLMHeadModel, QDQBertConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
        >>> config = QDQBertConfig.from_pretrained("google-bert/bert-base-cased")
        >>> config.is_decoder = True
        >>> model = QDQBertLMHeadModel.from_pretrained("google-bert/bert-base-cased", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Caching is pointless (and would waste memory) while training with labels.
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.LongTensor],
        past_key_values=None,
        attention_mask: Optional[torch.Tensor] = None,
        **model_kwargs,
    ):
        """
        Build the model inputs for one generation step: default the attention mask to all-ones and
        drop the already-cached prefix from `input_ids` when `past_key_values` is supplied.
        """
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every layer's cached key/value tensors along the batch axis for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
@add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING)
class QDQBertForMaskedLM(QDQBertPreTrainedModel):
    # MLM-head decoder weights/bias are tied to the input word embeddings.
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # Pooler not needed: MLM scores are computed from per-token hidden states.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.cls = QDQBertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Decoder projection of the MLM head; exposed for weight tying.
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Per-token vocabulary scores of shape (batch_size, sequence_length, vocab_size).
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs
    ):
        """
        Append a masked-out dummy [PAD] token to every sequence so generation can predict the token
        "after" the current input, and extend the attention mask accordingly.

        Raises:
            ValueError: if `config.pad_token_id` is not set.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # Fix: like `QDQBertLMHeadModel.prepare_inputs_for_generation`, build an all-ones mask on
        # the fly when the caller did not supply one, instead of crashing on `None.new_zeros` below.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    QDQBERT_START_DOCSTRING,
)
class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # The binary NSP classifier consumes the pooled [CLS] representation.
        self.bert = QDQBertModel(config)
        self.cls = QDQBertOnlyNSPHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, QDQBertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = QDQBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""

        # Honor the deprecated keyword while steering callers toward `labels`.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Binary is-next/random scores from the pooled output; shape (batch_size, 2).
        seq_relationship_scores = self.cls(encoder_out[1])

        next_sentence_loss = None
        if labels is not None:
            next_sentence_loss = CrossEntropyLoss()(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            tail = (seq_relationship_scores,) + encoder_out[2:]
            return tail if next_sentence_loss is None else (next_sentence_loss,) + tail

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
+
+
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForSequenceClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone + dropout + linear head over the pooled [CLS] representation.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled [CLS] output (index 1 of the backbone outputs).
        logits = self.classifier(self.dropout(backbone_out[1]))

        loss = None
        if labels is not None:
            # Infer the problem type once from label count/dtype and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            tail = (logits,) + backbone_out[2:]
            return tail if loss is None else (loss,) + tail

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
+
+
@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForMultipleChoice(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # One scalar score per choice; scores are reshaped to (batch, num_choices) in forward.
        self.bert = QDQBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        reference = input_ids if input_ids is not None else inputs_embeds
        num_choices = reference.shape[1]

        def _fold_choices(tensor):
            # Fold the choice axis into the batch axis: (batch, choices, seq) -> (batch * choices, seq).
            return None if tensor is None else tensor.view(-1, tensor.size(-1))

        flat_input_ids = _fold_choices(input_ids)
        flat_attention_mask = _fold_choices(attention_mask)
        flat_token_type_ids = _fold_choices(token_type_ids)
        flat_position_ids = _fold_choices(position_ids)
        flat_inputs_embeds = (
            None if inputs_embeds is None else inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
        )

        backbone_out = self.bert(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Score each (example, choice) pair from its pooled output, then unfold back to
        # (batch, num_choices) for the cross-entropy over choices.
        choice_scores = self.classifier(self.dropout(backbone_out[1]))
        reshaped_logits = choice_scores.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if not return_dict:
            tail = (reshaped_logits,) + backbone_out[2:]
            return tail if loss is None else (loss,) + tail

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
+
+
@add_start_docstrings(
    """
    QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForTokenClassification(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token labels, so the pooler is skipped and the head runs on every hidden state.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits of shape (batch_size, sequence_length, num_labels).
        logits = self.classifier(self.dropout(backbone_out[0]))

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            tail = (logits,) + backbone_out[2:]
            return tail if loss is None else (loss,) + tail

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
+
+
@add_start_docstrings(
    """
    QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    QDQBERT_START_DOCSTRING,
)
class QDQBertForQuestionAnswering(QDQBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token span head, so the pooler is not needed.
        self.bert = QDQBertModel(config, add_pooling_layer=False)
        # Single linear layer producing both start and end logits (num_labels is expected to be 2).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the 2-channel head output into per-token start/end logit vectors.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # (positions clamped to seq_length land on ignore_index and contribute no loss)
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Average of start- and end-position losses.
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
# Public API of this module, consumed by the package's lazy-import machinery.
__all__ = [
    "QDQBertForMaskedLM",
    "QDQBertForMultipleChoice",
    "QDQBertForNextSentencePrediction",
    "QDQBertForQuestionAnswering",
    "QDQBertForSequenceClassification",
    "QDQBertForTokenClassification",
    "QDQBertLayer",
    "QDQBertLMHeadModel",
    "QDQBertModel",
    "QDQBertPreTrainedModel",
    "load_tf_weights_in_qdqbert",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdfdeb5d179c8e954c9c6822d3f154d632394964
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING

from ....utils import _LazyModule
from ....utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static analyzers and type checkers see the real submodule contents.
    from .configuration_realm import *
    from .modeling_realm import *
    from .retrieval_realm import *
    from .tokenization_realm import *
    from .tokenization_realm_fast import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy so that heavy
    # submodules (e.g. modeling_realm, which pulls in torch) are imported only
    # on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf32378a604beca939ff6e0241f69f4c4382120
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""REALM model configuration."""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class RealmConfig(PretrainedConfig):
    r"""
    Configuration class shared by every REALM component:

    1. [`RealmEmbedder`]
    2. [`RealmScorer`]
    3. [`RealmKnowledgeAugEncoder`]
    4. [`RealmRetriever`]
    5. [`RealmReader`]
    6. [`RealmForOpenQA`]

    Instantiating this configuration with no arguments yields a configuration similar to that of the REALM
    [google/realm-cc-news-pretrained-embedder](https://huggingface.co/google/realm-cc-news-pretrained-embedder)
    architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
    outputs; read the [`PretrainedConfig`] documentation for the inherited attributes.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the REALM model, i.e. the number of distinct token ids accepted in the
            `inputs_ids` passed to [`RealmEmbedder`], [`RealmScorer`], [`RealmKnowledgeAugEncoder`], or
            [`RealmReader`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the encoder layers and the pooler layer.
        retriever_proj_size (`int`, *optional*, defaults to 128):
            Dimension of the retriever (embedder) projection.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_candidates (`int`, *optional*, defaults to 8):
            Number of candidates fed to the RealmScorer or RealmKnowledgeAugEncoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the "intermediate" (feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
            Non-linear activation in the encoder and pooler. If a string, `"gelu"`, `"relu"`, `"selu"` and
            `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            Maximum sequence length this model might ever be used with; typically a generous value such as 512,
            1024 or 2048.
        type_vocab_size (`int`, *optional*, defaults to 2):
            Vocabulary size of the `token_type_ids` passed when calling [`RealmEmbedder`], [`RealmScorer`],
            [`RealmKnowledgeAugEncoder`], or [`RealmReader`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated_normal_initializer for all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Epsilon used by the layer normalization layers.
        span_hidden_size (`int`, *optional*, defaults to 256):
            Dimension of the reader's spans.
        max_span_width (`int`, *optional*, defaults to 10):
            Max span width of the reader.
        reader_layer_norm_eps (`float`, *optional*, defaults to 1e-3):
            Epsilon used by the reader's layer normalization layers.
        reader_beam_size (`int`, *optional*, defaults to 5):
            Beam size of the reader.
        reader_seq_len (`int`, *optional*, defaults to 288+32):
            Maximum sequence length of the reader.
        num_block_records (`int`, *optional*, defaults to 13353718):
            Number of block records.
        searcher_beam_size (`int`, *optional*, defaults to 5000):
            Beam size of the searcher. Note that when eval mode is enabled, *searcher_beam_size* will be the
            same as *reader_beam_size*.

    Example:

    ```python
    >>> from transformers import RealmConfig, RealmEmbedder

    >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration
    >>> configuration = RealmConfig()

    >>> # Initializing a model (with random weights) from the google/realm-cc-news-pretrained-embedder style configuration
    >>> model = RealmEmbedder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "realm"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        retriever_proj_size=128,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_candidates=8,
        intermediate_size=3072,
        hidden_act="gelu_new",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        span_hidden_size=256,
        max_span_width=10,
        reader_layer_norm_eps=1e-3,
        reader_beam_size=5,
        reader_seq_len=320,  # 288 + 32
        num_block_records=13353718,
        searcher_beam_size=5000,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        # Shared encoder hyper-parameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Retriever / scorer hyper-parameters.
        self.retriever_proj_size = retriever_proj_size
        self.num_candidates = num_candidates

        # Reader hyper-parameters.
        self.span_hidden_size = span_hidden_size
        self.max_span_width = max_span_width
        self.reader_layer_norm_eps = reader_layer_norm_eps
        self.reader_beam_size = reader_beam_size
        self.reader_seq_len = reader_seq_len

        # Retrieval hyper-parameters.
        self.num_block_records = num_block_records
        self.searcher_beam_size = searcher_beam_size
+
+
+__all__ = ["RealmConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac25a177333ef4eb5eb18f628adf1235d2084fee
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py
@@ -0,0 +1,1862 @@
+# coding=utf-8
+# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch REALM model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ ModelOutput,
+)
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_realm import RealmConfig
+
+
+logger = logging.get_logger(__name__)
+_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder"
+_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder"
+_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer"
+_CONFIG_FOR_DOC = "RealmConfig"
+
+
def load_tf_weights_in_realm(model, config, tf_checkpoint_path):
    """
    Load TensorFlow checkpoint weights into a PyTorch REALM model, in place.

    TF variable names are rewritten to match the PyTorch module hierarchy; the rewrite
    rules depend on which REALM head `model` is (reader, open-QA, encoder, embedder/scorer).

    Args:
        model: The PyTorch model to populate.
        config: Unused here; kept for signature parity with other `load_tf_weights_in_*` loaders.
        tf_checkpoint_path: Path to the TensorFlow checkpoint.

    Returns:
        The same `model`, with its parameters overwritten from the checkpoint.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []

    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        # A standalone reader only owns variables under the "reader" scope.
        if isinstance(model, RealmReader) and "reader" not in name:
            logger.info(f"Skipping {name} as it is not {model.__class__.__name__}'s parameter")
            continue

        # For pretrained openqa reader
        if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmForOpenQA):
            name = name.replace("bert/", "reader/realm/")
            name = name.replace("cls/", "reader/cls/")

        # For pretrained encoder
        if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmKnowledgeAugEncoder):
            name = name.replace("bert/", "realm/")

        # For finetuned reader
        if name.startswith("reader"):
            reader_prefix = "" if isinstance(model, RealmReader) else "reader/"
            name = name.replace("reader/module/bert/", f"{reader_prefix}realm/")
            name = name.replace("reader/module/cls/", f"{reader_prefix}cls/")
            name = name.replace("reader/dense/", f"{reader_prefix}qa_outputs/dense_intermediate/")
            name = name.replace("reader/dense_1/", f"{reader_prefix}qa_outputs/dense_output/")
            name = name.replace("reader/layer_normalization", f"{reader_prefix}qa_outputs/layer_normalization")

        # For embedder and scorer
        # NOTE: replacement order matters — the longer "module/module/module/module/..." patterns
        # must be tried before the shorter ones they contain.
        if name.startswith("module/module/module/"):  # finetuned
            embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/"
            name = name.replace("module/module/module/module/bert/", f"{embedder_prefix}realm/")
            name = name.replace("module/module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/")
            name = name.replace("module/module/module/dense/", f"{embedder_prefix}cls/dense/")
            name = name.replace("module/module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/")
            name = name.replace("module/module/module/bert/", f"{embedder_prefix}realm/")
            name = name.replace("module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/")
        elif name.startswith("module/module/"):  # pretrained
            embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/"
            name = name.replace("module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/")
            name = name.replace("module/module/dense/", f"{embedder_prefix}cls/dense/")

        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # Walk the PyTorch module tree following the (rewritten) TF scope path.
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE(review): this `continue` advances the inner scope loop, not the outer
                    # variable loop, so later path segments are resolved against a stale pointer.
                    # Same quirk exists in the upstream BERT loader — confirm before changing.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF stores dense kernels transposed relative to torch.nn.Linear.weight.
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape, (
                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
            )
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
+
+
+class RealmEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+ )
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values_length: int = 0,
+ ) -> torch.Tensor:
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+ # issue #5664
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
class RealmSelfAttention(nn.Module):
    """
    Multi-head attention (self- or cross-attention) with optional relative position
    embeddings and decoder key/value caching. Mirrors the standard BERT attention layout.
    """

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per possible (query - key) distance in [-(L-1), L-1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Returns `(context_layer,)`, optionally extended with the attention probabilities
        (when `output_attentions`) and the key/value cache tuple (when decoding).
        """
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # With a cache only the newest query position exists; its distance to every key
                # is derived from key_length - 1.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift distances into the [0, 2L-2] index range of the distance embedding table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in RealmModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
+class RealmSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
# Maps `config._attn_implementation` to the attention class to instantiate;
# REALM only ships an eager (non-fused) implementation.
REALM_SELF_ATTENTION_CLASSES = {
    "eager": RealmSelfAttention,
}
+
+
class RealmAttention(nn.Module):
    """Attention sub-layer: a `RealmSelfAttention` followed by its output projection, with head pruning support."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        attention_class = REALM_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.self = attention_class(config, position_embedding_type=position_embedding_type)
        self.output = RealmSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads, shrinking the q/k/v and output projections in place."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink each projection along the head dimension.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep head bookkeeping in sync with the reduced projections.
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run self-attention, project its context, and pass through any extra outputs (probs, cache)."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        # Prepend the projected output; keep attention probs / cache entries as-is.
        return (attention_output,) + self_outputs[1:]
+
+
class RealmIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size with the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be either a string key into ACT2FN or a callable supplied directly.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
+class RealmOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
class RealmLayer(nn.Module):
    """
    One Transformer layer: self-attention, optional cross-attention (decoder-only),
    and a chunked feed-forward block.
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = RealmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention always uses absolute positions, regardless of the config's type.
            self.crossattention = RealmAttention(config, position_embedding_type="absolute")
        self.intermediate = RealmIntermediate(config)
        self.output = RealmOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Returns `(layer_output, *attention_outputs)`; for decoders the present
        key/value cache tuple is appended as the last element.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward applied in sequence chunks to bound peak memory.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Expansion + activation, then contraction with residual; called per chunk.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
+
+
class RealmEncoder(nn.Module):
    """
    Stack of `RealmLayer` modules with optional gradient checkpointing, key/value
    caching, and collection of per-layer hidden states / attention maps.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RealmLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """
        Run every layer in order. Returns a `BaseModelOutputWithPastAndCrossAttentions`
        (or, when `return_dict=False`, a tuple of the non-None equivalents).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Caching and checkpointing are mutually exclusive: checkpointed re-execution
        # would invalidate cached key/values.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the input to this layer (so index i holds the pre-layer-i states).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # For decoders the layer appends its key/value cache as the last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
class RealmPooler(nn.Module):
    """Pools a sequence of hidden states into one vector per example.

    Follows the BERT convention: the hidden state of the first token is
    passed through a dense layer followed by a tanh activation.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pool" by keeping only the first token's representation.
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
+
+
@dataclass
class RealmEmbedderOutput(ModelOutput):
    """
    Outputs of [`RealmEmbedder`] models.

    Args:
        projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):

            Projected score.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # [batch_size, retriever_proj_size] embedding used for relevance scoring.
    projected_score: Optional[torch.FloatTensor] = None
    # Per-layer hidden states / attentions, only populated when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
@dataclass
class RealmScorerOutput(ModelOutput):
    """
    Outputs of [`RealmScorer`] models.

    Args:
        relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`):
            The relevance score of document candidates (before softmax).
        query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`):
            Query score derived from the query embedder.
        candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`):
            Candidate score derived from the embedder.
    """

    # relevance_score = einsum("bd,bnd->bn", query_score, candidate_score).
    relevance_score: Optional[torch.FloatTensor] = None
    query_score: Optional[torch.FloatTensor] = None
    candidate_score: Optional[torch.FloatTensor] = None
+
+
@dataclass
class RealmReaderOutput(ModelOutput):
    """
    Outputs of [`RealmReader`] models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Total loss.
        retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Retriever loss.
        reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided):
            Reader loss.
        retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*):
            Whether or not an evidence block contains answer.
        reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*):
            Whether or not a span candidate contains answer.
        block_idx (`torch.LongTensor` of shape `()`):
            The index of the retrieved evidence block in which the predicted answer is most likely.
        candidate (`torch.LongTensor` of shape `()`):
            The index of the retrieved span candidates in which the predicted answer is most likely.
        start_pos (`torch.IntTensor` of shape `()`):
            Predicted answer starting position in *RealmReader*'s inputs.
        end_pos (`torch.IntTensor` of shape `()`):
            Predicted answer ending position in *RealmReader*'s inputs.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    retriever_loss: Optional[torch.FloatTensor] = None
    reader_loss: Optional[torch.FloatTensor] = None
    # NOTE(review): the annotations below previously used bare `torch.BoolTensor`
    # and the dtype object `torch.int32` as type annotations despite defaulting to
    # None; widened to `Optional[...]` tensor types for accuracy.
    retriever_correct: Optional[torch.BoolTensor] = None
    reader_correct: Optional[torch.BoolTensor] = None
    block_idx: Optional[torch.LongTensor] = None
    candidate: Optional[torch.LongTensor] = None
    start_pos: Optional[torch.IntTensor] = None
    end_pos: Optional[torch.IntTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
@dataclass
class RealmForOpenQAOutput(ModelOutput):
    """

    Outputs of [`RealmForOpenQA`] models.

    Args:
        reader_output (`dict`):
            Reader output.
        predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`):
            Predicted answer ids.
    """

    # NOTE(review): annotation widened to Optional since the field defaults to None.
    reader_output: Optional[dict] = None
    predicted_answer_ids: Optional[torch.LongTensor] = None
+
+
class RealmPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` may be given either as a registry key or as a callable.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
+
+
class RealmLMPredictionHead(nn.Module):
    """Transforms hidden states and projects them onto the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = RealmPredictionHeadTransform(config)

        # The decoder weights are tied to the input embeddings elsewhere;
        # only the per-token output bias is owned here.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Link the two so the bias is resized along with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-establish the link after weight tying / resizing.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
+
+
class RealmOnlyMLMHead(nn.Module):
    """Masked-LM head: a thin wrapper around `RealmLMPredictionHead`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = RealmLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return vocabulary logits for every position of `sequence_output`."""
        return self.predictions(sequence_output)
+
+
class RealmScorerProjection(nn.Module):
    """Projects pooled hidden states into the retriever embedding space."""

    def __init__(self, config):
        super().__init__()
        # NOTE(review): `predictions` is never used in `forward`; kept as-is for
        # checkpoint/state-dict compatibility.
        self.predictions = RealmLMPredictionHead(config)
        self.dense = nn.Linear(config.hidden_size, config.retriever_proj_size)
        self.LayerNorm = nn.LayerNorm(config.retriever_proj_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.dense(hidden_states))
+
+
class RealmReaderProjection(nn.Module):
    """Scores every candidate answer span inside the retrieved evidence blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Produces concatenated start/end projections in a single matmul.
        self.dense_intermediate = nn.Linear(config.hidden_size, config.span_hidden_size * 2)
        self.dense_output = nn.Linear(config.span_hidden_size, 1)
        self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps)
        self.relu = nn.ReLU()

    def forward(self, hidden_states, block_mask):
        """Compute logits for all spans up to `config.max_span_width`.

        Args:
            hidden_states: [reader_beam_size, max_sequence_len, hidden_size]
            block_mask: [reader_beam_size, max_sequence_len], marks evidence-block tokens.

        Returns:
            reader_logits: [reader_beam_size, num_candidates]
            candidate_starts / candidate_ends: [num_spans] span boundary indices.
        """

        def enumerate_spans(masks):
            # Enumerate all spans of width 1..max_span_width (grouped by width),
            # plus a per-retrieval mask marking spans whose two endpoints both
            # fall inside the evidence block.
            seq_len = masks.shape[-1]

            width_starts, width_ends = [], []
            for width in range(1, self.config.max_span_width + 1):
                width_starts.append(torch.arange(seq_len - width + 1, device=masks.device))
                width_ends.append(torch.arange(width - 1, seq_len, device=masks.device))

            # [num_spans]
            starts = torch.cat(width_starts, 0)
            ends = torch.cat(width_ends, 0)

            # [num_retrievals, num_spans]
            valid = torch.index_select(masks, dim=-1, index=starts) * torch.index_select(
                masks, dim=-1, index=ends
            )
            return starts, ends, valid

        def mask_to_score(mask, dtype=torch.float32):
            # Turn a {0,1} mask into an additive penalty: 0 where kept,
            # the most negative finite value where masked out.
            return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

        # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
        projected = self.dense_intermediate(hidden_states)
        # Split into start-point and end-point projections.
        start_projection, end_projection = projected.chunk(2, dim=-1)

        candidate_starts, candidate_ends, candidate_mask = enumerate_spans(block_mask)

        # Span representation = start projection + end projection.
        # [reader_beam_size, num_candidates, span_hidden_size]
        candidate_hidden = torch.index_select(
            start_projection, dim=1, index=candidate_starts
        ) + torch.index_select(end_projection, dim=1, index=candidate_ends)

        candidate_hidden = self.layer_normalization(self.relu(candidate_hidden))
        # [reader_beam_size, num_candidates]
        reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
        # Push spans outside the evidence block to (near) -inf.
        reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)

        return reader_logits, candidate_starts, candidate_ends
+
+
# Shared docstring fragment prepended to Realm model classes via `add_start_docstrings`.
REALM_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`RealmConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Shared forward() docstring template; `{0}` is filled with the input shape by
# `add_start_docstrings_to_model_forward`.
REALM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
class RealmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RealmConfig
    load_tf_weights = load_tf_weights_in_realm
    base_model_prefix = "realm"

    def _init_weights(self, module):
        """Initialize `module`'s weights according to the configured scheme."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version, which uses truncated_normal
            # for initialization; cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding stays at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Identity-initialized LayerNorm.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _flatten_inputs(self, *inputs):
        """Collapse each tensor in `inputs` to shape (-1, last_dim); `None`s pass through."""
        flattened = []
        for tensor in inputs:
            if tensor is not None and tensor.dim() > 2:
                tensor = tensor.view((-1, tensor.shape[-1]))
            flattened.append(tensor)
        return flattened
+
+
class RealmBertModel(RealmPreTrainedModel):
    """
    Same as the original BertModel but remove docstrings.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RealmEmbeddings(config)
        self.encoder = RealmEncoder(config)

        self.pooler = RealmPooler(config) if add_pooling_layer else None

        # Weights initialization is mostly managed by other Realm models,
        # but we also have them initialized here to keep a consistency.
        self.post_init()

    def get_input_embeddings(self):
        # Word-embedding table lives on the embeddings submodule.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed the inputs, run the encoder, and (optionally) pool the output."""
        # Fall back to config defaults for unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Key/value caching only makes sense for decoder-style usage.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            # Default: attend to everything (including cached positions).
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the registered buffer so device/dtype are already right.
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
+
+
@add_start_docstrings(
    "The embedder of REALM outputting projected score that will be used to calculate relevance score.",
    REALM_START_DOCSTRING,
)
class RealmEmbedder(RealmPreTrainedModel):
    # Bias is tied to the decoder; excluded from missing-key warnings on load.
    _tied_weights_keys = ["cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.realm = RealmBertModel(self.config)
        self.cls = RealmScorerProjection(self.config)
        self.post_init()

    def get_input_embeddings(self):
        return self.realm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.realm.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmEmbedderOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RealmEmbedder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
        >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> projected_score = outputs.projected_score
        ```
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        realm_outputs = self.realm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size, hidden_size]
        pooler_output = realm_outputs[1]
        # [batch_size, retriever_proj_size]
        projected_score = self.cls(pooler_output)

        if not return_dict:
            # Append hidden_states / attentions from the backbone (tuple slots 2:4).
            return (projected_score,) + realm_outputs[2:4]
        else:
            return RealmEmbedderOutput(
                projected_score=projected_score,
                hidden_states=realm_outputs.hidden_states,
                attentions=realm_outputs.attentions,
            )
+
+
@add_start_docstrings(
    "The scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).",
    REALM_START_DOCSTRING,
)
class RealmScorer(RealmPreTrainedModel):
    r"""
    Args:
        query_embedder ([`RealmEmbedder`]):
            Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences.
    """

    def __init__(self, config, query_embedder=None):
        super().__init__(config)

        self.embedder = RealmEmbedder(self.config)

        # Queries may be embedded by a separate embedder, otherwise share `self.embedder`.
        self.query_embedder = query_embedder if query_embedder is not None else self.embedder

        self.post_init()

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        candidate_input_ids: Optional[torch.LongTensor] = None,
        candidate_attention_mask: Optional[torch.FloatTensor] = None,
        candidate_token_type_ids: Optional[torch.LongTensor] = None,
        candidate_inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmScorerOutput]:
        r"""
        candidate_input_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`):
            Indices of candidate input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        candidate_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        candidate_token_type_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        candidate_inputs_embeds (`torch.FloatTensor` of shape `(batch_size * num_candidates, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `candidate_input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert *candidate_input_ids* indices
            into associated vectors than the model's internal embedding lookup matrix.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmScorer

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer")
        >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2)

        >>> # batch_size = 2, num_candidates = 2
        >>> input_texts = ["How are you?", "What is the item in the picture?"]
        >>> candidates_texts = [["Hello world!", "Nice to meet you!"], ["A cute cat.", "An adorable dog."]]

        >>> inputs = tokenizer(input_texts, return_tensors="pt")
        >>> candidates_inputs = tokenizer.batch_encode_candidates(candidates_texts, max_length=10, return_tensors="pt")

        >>> outputs = model(
        ...     **inputs,
        ...     candidate_input_ids=candidates_inputs.input_ids,
        ...     candidate_attention_mask=candidates_inputs.attention_mask,
        ...     candidate_token_type_ids=candidates_inputs.token_type_ids,
        ... )
        >>> relevance_score = outputs.relevance_score
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or input_embeds.")

        if candidate_input_ids is None and candidate_inputs_embeds is None:
            raise ValueError("You have to specify either candidate_input_ids or candidate_inputs_embeds.")

        # Embed the queries; first output slot is the projected score.
        query_outputs = self.query_embedder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size * num_candidates, candidate_seq_len]
        (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs(
            candidate_input_ids, candidate_attention_mask, candidate_token_type_ids
        )

        # Embed all candidates in one flattened batch.
        candidate_outputs = self.embedder(
            flattened_input_ids,
            attention_mask=flattened_attention_mask,
            token_type_ids=flattened_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=candidate_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size, retriever_proj_size]
        query_score = query_outputs[0]
        # [batch_size * num_candidates, retriever_proj_size]
        candidate_score = candidate_outputs[0]
        # [batch_size, num_candidates, retriever_proj_size]
        candidate_score = candidate_score.view(-1, self.config.num_candidates, self.config.retriever_proj_size)
        # [batch_size, num_candidates] — inner product of query and each candidate.
        relevance_score = torch.einsum("bd,bnd->bn", query_score, candidate_score)

        if not return_dict:
            return relevance_score, query_score, candidate_score

        return RealmScorerOutput(
            relevance_score=relevance_score, query_score=query_score, candidate_score=candidate_score
        )
+
+
@add_start_docstrings(
    "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood"
    " loss.",
    REALM_START_DOCSTRING,
)
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
    """BERT-style encoder with an MLM head whose loss is marginalized over retrieved candidate blocks."""

    # MLM decoder weights are tied to the input word embeddings.
    _tied_weights_keys = ["cls.predictions.decoder"]

    def __init__(self, config):
        super().__init__(config)
        self.realm = RealmBertModel(self.config)
        self.cls = RealmOnlyMLMHead(self.config)
        # Weight init + final processing (defined on the PreTrainedModel base).
        self.post_init()

    def get_input_embeddings(self):
        return self.realm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.realm.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        # Keep the bias in sync with the replaced decoder weights.
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(
        REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
    )
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        relevance_score: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mlm_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        relevance_score (`torch.FloatTensor` of shape `(batch_size, num_candidates)`, *optional*):
            Relevance score derived from RealmScorer, must be specified if you want to compute the masked language
            modeling loss.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        mlm_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid calculating joint loss on certain positions. If not specified, the loss will not be masked.
            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, RealmKnowledgeAugEncoder

        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> model = RealmKnowledgeAugEncoder.from_pretrained(
        ...     "google/realm-cc-news-pretrained-encoder", num_candidates=2
        ... )

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and relevance_score is None:
            raise ValueError(
                "You have to specify `relevance_score` when `labels` is specified in order to compute loss."
            )

        # Collapse the candidate axis so the backbone sees a flat batch
        # (presumably (batch, num_candidates, seq) -> (batch * num_candidates, seq);
        # `_flatten_inputs` is defined outside this view — confirm on the base class).
        (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs(
            input_ids, attention_mask, token_type_ids
        )

        joint_outputs = self.realm(
            flattened_input_ids,
            attention_mask=flattened_attention_mask,
            token_type_ids=flattened_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [batch_size * num_candidates, joint_seq_len, hidden_size]
        joint_output = joint_outputs[0]
        # [batch_size * num_candidates, joint_seq_len, vocab_size]
        prediction_scores = self.cls(joint_output)
        # [batch_size, num_candidates]
        candidate_score = relevance_score

        masked_lm_loss = None
        if labels is not None:
            batch_size, seq_length = labels.size()

            # Default: every position contributes to the loss.
            if mlm_mask is None:
                mlm_mask = torch.ones_like(labels, dtype=torch.float32)
            else:
                mlm_mask = mlm_mask.type(torch.float32)

            # Compute marginal log-likelihood
            loss_fct = CrossEntropyLoss(reduction="none")  # -100 index = padding token

            # [batch_size * num_candidates * joint_seq_len, vocab_size]
            mlm_logits = prediction_scores.view(-1, self.config.vocab_size)
            # [batch_size * num_candidates * joint_seq_len]
            # Labels are shared across candidates, so repeat them along the candidate axis.
            mlm_targets = labels.tile(1, self.config.num_candidates).view(-1)
            # [batch_size, num_candidates, joint_seq_len]
            # Negated per-token cross-entropy == log-probability of the gold token.
            masked_lm_log_prob = -loss_fct(mlm_logits, mlm_targets).view(
                batch_size, self.config.num_candidates, seq_length
            )
            # [batch_size, num_candidates, 1]
            candidate_log_prob = candidate_score.log_softmax(-1).unsqueeze(-1)
            # [batch_size, num_candidates, joint_seq_len]
            joint_gold_log_prob = candidate_log_prob + masked_lm_log_prob
            # [batch_size, joint_seq_len] -- marginalize over the candidate axis.
            marginal_gold_log_probs = joint_gold_log_prob.logsumexp(1)
            # [] scalar; nansum maps a NaN result (e.g. an all-zero mlm_mask) to 0 instead of propagating it.
            masked_lm_loss = -torch.nansum(torch.sum(marginal_gold_log_probs * mlm_mask) / torch.sum(mlm_mask))

        if not return_dict:
            output = (prediction_scores,) + joint_outputs[2:4]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=joint_outputs.hidden_states,
            attentions=joint_outputs.attentions,
        )
+
+
@add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING)
class RealmReader(RealmPreTrainedModel):
    """Extractive reader: scores candidate answer spans inside the retrieved evidence blocks."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.realm = RealmBertModel(config)
        self.cls = RealmOnlyMLMHead(config)
        # Projects token representations to per-span start/end scores.
        self.qa_outputs = RealmReaderProjection(config)

        self.post_init()

    @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("reader_beam_size, sequence_length"))
    @replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        relevance_score: Optional[torch.FloatTensor] = None,
        block_mask: Optional[torch.BoolTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        has_answers: Optional[torch.BoolTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmReaderOutput]:
        r"""
        relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*):
            Relevance score, which must be specified if you want to compute the logits and marginal log loss.
        block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*):
            The mask of the evidence block, which must be specified if you want to compute the logits and marginal log
            loss.
        start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        has_answers (`torch.BoolTensor` of shape `(searcher_beam_size,)`, *optional*):
            Whether or not the evidence block has answer(s).

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if relevance_score is None:
            raise ValueError("You have to specify `relevance_score` to calculate logits and loss.")
        if block_mask is None:
            raise ValueError("You have to specify `block_mask` to separate question block and evidence block.")
        if token_type_ids.size(1) < self.config.max_span_width:
            raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.")
        outputs = self.realm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # [reader_beam_size, joint_seq_len, hidden_size]
        sequence_output = outputs[0]

        # Score every candidate span per block.
        # [reader_beam_size, num_candidates], [num_candidates], [num_candidates]
        reader_logits, candidate_starts, candidate_ends = self.qa_outputs(
            sequence_output, block_mask[0 : self.config.reader_beam_size]
        )
        # [searcher_beam_size, 1]
        retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1)
        # [reader_beam_size, num_candidates] -- the retrieval score biases every span of its block.
        reader_logits += retriever_logits
        # [] index of the best-scoring block
        predicted_block_index = torch.argmax(torch.max(reader_logits, dim=1).values)
        # [] index of the best-scoring candidate span
        predicted_candidate = torch.argmax(torch.max(reader_logits, dim=0).values)
        # [1]
        predicted_start = torch.index_select(candidate_starts, dim=0, index=predicted_candidate)
        # [1]
        predicted_end = torch.index_select(candidate_ends, dim=0, index=predicted_candidate)

        total_loss = None
        retriever_loss = None
        reader_loss = None
        retriever_correct = None
        reader_correct = None
        if start_positions is not None and end_positions is not None and has_answers is not None:

            def compute_correct_candidates(candidate_starts, candidate_ends, gold_starts, gold_ends):
                """Compute correct span."""
                # [reader_beam_size, num_answers, num_candidates]
                is_gold_start = torch.eq(
                    torch.unsqueeze(torch.unsqueeze(candidate_starts, 0), 0), torch.unsqueeze(gold_starts, -1)
                )
                is_gold_end = torch.eq(
                    torch.unsqueeze(torch.unsqueeze(candidate_ends, 0), 0), torch.unsqueeze(gold_ends, -1)
                )

                # [reader_beam_size, num_candidates] -- correct if both endpoints match any gold answer.
                return torch.any(torch.logical_and(is_gold_start, is_gold_end), 1)

            def marginal_log_loss(logits, is_correct):
                """Loss based on the negative marginal log-likelihood."""

                def mask_to_score(mask, dtype=torch.float32):
                    # Additive mask: incorrect entries get dtype-min so they drop out of logsumexp.
                    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

                # []
                log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
                log_denominator = torch.logsumexp(logits, dim=-1)
                return log_denominator - log_numerator

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # `-1` is reserved for no answer.
            ignored_index = sequence_output.size(1)
            start_positions = start_positions.clamp(-1, ignored_index)
            end_positions = end_positions.clamp(-1, ignored_index)

            retriever_correct = has_answers
            any_retriever_correct = torch.any(retriever_correct)

            reader_correct = compute_correct_candidates(
                candidate_starts=candidate_starts,
                candidate_ends=candidate_ends,
                gold_starts=start_positions[0 : self.config.reader_beam_size],
                gold_ends=end_positions[0 : self.config.reader_beam_size],
            )
            any_reader_correct = torch.any(reader_correct)

            retriever_loss = marginal_log_loss(relevance_score, retriever_correct)
            reader_loss = marginal_log_loss(reader_logits.view(-1), reader_correct.view(-1))
            # Zero out each loss when it has no correct candidate at all (nothing to marginalize over).
            retriever_loss *= any_retriever_correct.type(torch.float32)
            reader_loss *= any_reader_correct.type(torch.float32)

            total_loss = (retriever_loss + reader_loss).mean()

        if not return_dict:
            output = (predicted_block_index, predicted_candidate, predicted_start, predicted_end) + outputs[2:]
            return (
                ((total_loss, retriever_loss, reader_loss, retriever_correct, reader_correct) + output)
                if total_loss is not None
                else output
            )

        return RealmReaderOutput(
            loss=total_loss,
            retriever_loss=retriever_loss,
            reader_loss=reader_loss,
            retriever_correct=retriever_correct,
            reader_correct=reader_correct,
            block_idx=predicted_block_index,
            candidate=predicted_candidate,
            start_pos=predicted_start,
            end_pos=predicted_end,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
# Docstring template for `RealmForOpenQA.forward`; the `{0}` placeholders are filled with the
# expected input shape by `add_start_docstrings_to_model_forward` below.
REALM_FOR_OPEN_QA_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token (should not be used in this model by design).

            [What are token type IDs?](../glossary#token-type-ids)
        answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*):
            Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-1` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "`RealmForOpenQA` for end-to-end open domain question answering.",
    REALM_START_DOCSTRING,
)
class RealmForOpenQA(RealmPreTrainedModel):
    """Ties together the question embedder, an external retriever, and the extractive reader."""

    def __init__(self, config, retriever=None):
        super().__init__(config)
        self.embedder = RealmEmbedder(config)
        self.reader = RealmReader(config)
        # Pre-computed evidence-block embeddings used for the top-k inner-product search.
        # Registered on CPU; `forward` runs the search there (see the "CPU computation" markers).
        # Move it explicitly with `block_embedding_to` if desired.
        self.register_buffer(
            "block_emb",
            torch.zeros(()).new_empty(
                size=(config.num_block_records, config.retriever_proj_size),
                dtype=torch.float32,
                device=torch.device("cpu"),
            ),
        )
        self.retriever = retriever

        self.post_init()

    @property
    def searcher_beam_size(self):
        # Wider beam while training; at eval time retrieve only what the reader consumes.
        if self.training:
            return self.config.searcher_beam_size
        return self.config.reader_beam_size

    def block_embedding_to(self, device):
        """Send `self.block_emb` to a specific device.

        Args:
            device (`str` or `torch.device`):
                The device to which `self.block_emb` will be sent.
        """

        self.block_emb = self.block_emb.to(device)

    @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length"))
    @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        answer_ids: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, RealmForOpenQAOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import RealmForOpenQA, RealmRetriever, AutoTokenizer

        >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-orqa-nq-openqa")
        >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever)

        >>> question = "Who is the pioneer in modern computer science?"
        >>> question_ids = tokenizer([question], return_tensors="pt")
        >>> answer_ids = tokenizer(
        ...     ["alan mathison turing"],
        ...     add_special_tokens=False,
        ...     return_token_type_ids=False,
        ...     return_attention_mask=False,
        ... ).input_ids

        >>> reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False)
        >>> predicted_answer = tokenizer.decode(predicted_answer_ids)
        >>> loss = reader_output.loss
        ```"""

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and input_ids.shape[0] != 1:
            raise ValueError("The batch_size of the inputs must be 1.")

        # Project the question into the retrieval space.
        question_outputs = self.embedder(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
        )
        # [1, projection_size]
        question_projection = question_outputs[0]

        # CPU computation starts.
        # [1, block_emb_size] -- inner product of the question with every block embedding.
        batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device))
        # [1, searcher_beam_size]
        _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1)
        # [searcher_beam_size] -- squeeze relies on the batch size of 1 enforced above.
        retrieved_block_ids = retrieved_block_ids.squeeze()
        # [searcher_beam_size, projection_size]
        retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids)
        # CPU computation ends.

        # Retrieve possible answers
        has_answers, start_pos, end_pos, concat_inputs = self.retriever(
            retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len
        )

        concat_inputs = concat_inputs.to(self.reader.device)
        # True on evidence-block tokens: not a special token AND token_type_id == 1 (segment B).
        block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device)
        block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool))

        if has_answers is not None:
            has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device)
            start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device)
            end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device)

        # [searcher_beam_size] -- fresh relevance logits against the retrieved blocks only.
        retrieved_logits = torch.einsum(
            "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device)
        )

        # Only the top `reader_beam_size` retrieved blocks are actually read.
        reader_output = self.reader(
            input_ids=concat_inputs.input_ids[0 : self.config.reader_beam_size],
            attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size],
            token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size],
            relevance_score=retrieved_logits,
            block_mask=block_mask,
            has_answers=has_answers,
            start_positions=start_pos,
            end_positions=end_pos,
            return_dict=True,
        )

        # Slice the predicted answer span out of the predicted block (end position is inclusive).
        predicted_block = concat_inputs.input_ids[reader_output.block_idx]
        predicted_answer_ids = predicted_block[reader_output.start_pos : reader_output.end_pos + 1]

        if not return_dict:
            return reader_output, predicted_answer_ids

        return RealmForOpenQAOutput(
            reader_output=reader_output,
            predicted_answer_ids=predicted_answer_ids,
        )
+
+
# Public names re-exported by the package's __init__.
__all__ = [
    "RealmEmbedder",
    "RealmForOpenQA",
    "RealmKnowledgeAugEncoder",
    "RealmPreTrainedModel",
    "RealmReader",
    "RealmScorer",
    "load_tf_weights_in_realm",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c084f1d2090975f8d6ed91665caf21e3e7829c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py
@@ -0,0 +1,176 @@
+# coding=utf-8
+# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""REALM Retriever model implementation."""
+
+import os
+from typing import Optional, Union
+
+import numpy as np
+from huggingface_hub import hf_hub_download
+
+from transformers import AutoTokenizer
+
+from ....utils import logging, strtobool
+
+
# Filename under which the pickled evidence-block array is stored/looked up.
_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy"


logger = logging.get_logger(__name__)
+
+
def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray:
    """Load a TFRecord file of evidence blocks into a single numpy array.

    TensorFlow is imported lazily so the rest of the module works without it.
    """
    import tensorflow.compat.v1 as tf

    # 512 MiB read buffer; the whole record set is materialized as one batch.
    blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True)
    np_record = next(blocks_dataset.take(1).as_numpy_iterator())

    return np_record
+
+
class ScaNNSearcher:
    """Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included."""

    def __init__(
        self,
        db,
        num_neighbors,
        dimensions_per_block=2,
        num_leaves=1000,
        num_leaves_to_search=100,
        training_sample_size=100000,
    ):
        """Build scann searcher."""

        # Imported lazily: `scann` is an optional dependency.
        from scann.scann_ops.py.scann_ops_pybind import builder as Builder

        # Maximum-inner-product search (dot product) over the block embeddings.
        builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product")
        builder = builder.tree(
            num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size
        )
        builder = builder.score_ah(dimensions_per_block=dimensions_per_block)

        self.searcher = builder.build()

    def search_batched(self, question_projection):
        """Return int64 ids of the nearest blocks for a batch of query projections."""
        retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu())
        return retrieved_block_ids.astype("int64")
+
+
class RealmRetriever:
    """The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as
    answer positions.

    Parameters:
        block_records (`np.ndarray`):
            A numpy array which contains evidence texts.
        tokenizer ([`RealmTokenizer`]):
            The tokenizer to encode retrieved texts.
    """

    def __init__(self, block_records, tokenizer):
        super().__init__()
        self.block_records = block_records
        self.tokenizer = tokenizer

    def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"):
        """Pair the question with each retrieved block and locate answer spans.

        Returns `(has_answers, start_pos, end_pos, concat_inputs)`; the first three are `None`
        when `answer_ids` is not given.
        """
        # Gather the raw evidence texts for the retrieved ids.
        retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0)

        question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True)

        text = []
        text_pair = []
        for retrieved_block in retrieved_blocks:
            text.append(question)
            # Block records are stored as bytes; decode to str for the tokenizer.
            text_pair.append(retrieved_block.decode())

        concat_inputs = self.tokenizer(
            text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length
        )
        concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors)

        if answer_ids is not None:
            return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,)
        else:
            return (None, None, None, concat_inputs_tensors)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs):
        """Load the block records (numpy pickle) and tokenizer from a local dir or the Hub."""
        if os.path.isdir(pretrained_model_name_or_path):
            block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME)
        else:
            block_records_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs
            )
        # `np.load(..., allow_pickle=True)` below can execute arbitrary code, so require an explicit opt-in.
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
                "potentially malicious. It's recommended to never unpickle data that could have come from an "
                "untrusted source, or that could have been tampered with. If you already verified the pickle "
                "data and decided to use it, you can set the environment variable "
                "`TRUST_REMOTE_CODE` to `True` to allow it."
            )
        block_records = np.load(block_records_path, allow_pickle=True)

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)

        return cls(block_records, tokenizer)

    def save_pretrained(self, save_directory):
        """Persist the block records and tokenizer files into `save_directory`."""
        # save block records
        np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records)
        # save tokenizer
        self.tokenizer.save_pretrained(save_directory)

    def block_has_answer(self, concat_inputs, answer_ids):
        """check if retrieved_blocks has answers."""
        has_answers = []
        start_pos = []
        end_pos = []
        max_answers = 0

        for input_id in concat_inputs.input_ids:
            input_id_list = input_id.tolist()
            # Check answers between two [SEP] tokens
            # (the evidence block lives between the first and second [SEP]).
            first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id)
            second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id)

            start_pos.append([])
            end_pos.append([])
            for answer in answer_ids:
                for idx in range(first_sep_idx + 1, second_sep_idx):
                    # Cheap first-token check before comparing the full answer slice.
                    if answer[0] == input_id_list[idx]:
                        if input_id_list[idx : idx + len(answer)] == answer:
                            start_pos[-1].append(idx)
                            end_pos[-1].append(idx + len(answer) - 1)

            if len(start_pos[-1]) == 0:
                has_answers.append(False)
            else:
                has_answers.append(True)
                if len(start_pos[-1]) > max_answers:
                    max_answers = len(start_pos[-1])

        # Pad -1 to max_answers so every row carries the same number of spans.
        for start_pos_, end_pos_ in zip(start_pos, end_pos):
            if len(start_pos_) < max_answers:
                padded = [-1] * (max_answers - len(start_pos_))
                start_pos_ += padded
                end_pos_ += padded
        return has_answers, start_pos, end_pos
+
+
# Public API of this module.
__all__ = ["RealmRetriever"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e69bc4bc2bce4024322950add44e5f58b04f1c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py
@@ -0,0 +1,563 @@
+# coding=utf-8
+# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for REALM."""
+
+import collections
+import os
+import unicodedata
+from typing import List, Optional, Tuple
+
+from ....tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ....tokenization_utils_base import BatchEncoding
+from ....utils import PaddingStrategy, logging
+
+
logger = logging.get_logger(__name__)

# Expected vocabulary filename inside a pretrained tokenizer directory.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    # Each line holds one token; its (0-based) line number becomes the token id.
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader.readlines()):
            vocab[line.rstrip("\n")] = index
    return vocab
+
+
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    stripped = text.strip()
    # An empty (or all-whitespace) input yields no tokens.
    return stripped.split() if stripped else []
+
+
+class RealmTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a REALM tokenizer.
+
+ [`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and
+ wordpiece.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ File containing the vocabulary.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+ Whether or not to do basic tokenization before WordPiece.
+ never_split (`Iterable`, *optional*):
+ Collection of tokens which will never be split during tokenization. Only has an effect when
+ `do_basic_tokenize=True`
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+ Whether or not to tokenize Chinese characters.
+
+ This should likely be deactivated for Japanese (see this
+ [issue](https://github.com/huggingface/transformers/issues/328)).
+ strip_accents (`bool`, *optional*):
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+ value for `lowercase` (as in the original BERT).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping id -> token, used for decoding.
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
        # NOTE(review): the vocab/tokenizers are built before the base __init__ is called —
        # presumably the base class needs a working tokenizer; keep this ordering.
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
+
    @property
    def do_lower_case(self):
        # Mirrors the flag stored on the basic tokenizer.
        return self.basic_tokenizer.do_lower_case
+
    @property
    def vocab_size(self):
        # Size of the base vocabulary only (added tokens are merged in `get_vocab`).
        return len(self.vocab)
+
    def get_vocab(self):
        # Base vocab merged with tokens added after initialization.
        return dict(self.vocab, **self.added_tokens_encoder)
+
    def _tokenize(self, text):
        """Split `text` into wordpiece tokens, optionally basic-tokenizing (punctuation, casing) first."""
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
+
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        # Unknown ids fall back to the [UNK] token string.
        return self.ids_to_tokens.get(index, self.unk_token)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ out_string = " ".join(tokens).replace(" ##", "").strip()
+ return out_string
+
    def batch_encode_candidates(self, text, **kwargs):
        r"""
        Encode a batch of text or text pair. This method is similar to regular __call__ method but has the following
        differences:

        1. Handle additional num_candidate axis. (batch_size, num_candidates, text)
        2. Always pad the sequences to *max_length*.
        3. Must specify *max_length* in order to stack packs of candidates into a batch.

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            text (`List[List[str]]`):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            text_pair (`List[List[str]]`, *optional*):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            **kwargs:
                Keyword arguments of the __call__ method.

        Returns:
            [`BatchEncoding`]: Encoded text or text pair.

        Example:

        ```python
        >>> from transformers import RealmTokenizer

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        ```"""

        # Always using a fixed sequence length to encode in order to stack candidates into a batch.
        kwargs["padding"] = PaddingStrategy.MAX_LENGTH

        batch_text = text
        batch_text_pair = kwargs.pop("text_pair", None)
        # Tensor conversion is deferred until all candidate lists are encoded.
        return_tensors = kwargs.pop("return_tensors", None)

        output_data = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }

        # Encode each example's candidate list separately, then stack afterwards.
        for idx, candidate_text in enumerate(batch_text):
            if batch_text_pair is not None:
                candidate_text_pair = batch_text_pair[idx]
            else:
                candidate_text_pair = None

            encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs)

            encoded_input_ids = encoded_candidates.get("input_ids")
            encoded_attention_mask = encoded_candidates.get("attention_mask")
            encoded_token_type_ids = encoded_candidates.get("token_type_ids")

            if encoded_input_ids is not None:
                output_data["input_ids"].append(encoded_input_ids)
            if encoded_attention_mask is not None:
                output_data["attention_mask"].append(encoded_attention_mask)
            if encoded_token_type_ids is not None:
                output_data["token_type_ids"].append(encoded_token_type_ids)

        # Drop keys the underlying tokenizer did not produce.
        output_data = {key: item for key, item in output_data.items() if len(item) != 0}

        return BatchEncoding(output_data, tensor_type=return_tensors)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A REALM sequence has the following format:
+
+ - single sequence: `[CLS] X [SEP]`
+ - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is not None:
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
+ pair mask has the following format:
+
+ ```
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+ | first sequence | second sequence |
+ ```
+
+ If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ index = 0
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!"
+ )
+ index = token_index
+ writer.write(token + "\n")
+ index += 1
+ return (vocab_file,)
+
+
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        # Stored as a set for O(1) membership checks in tokenize()/_run_split_on_punc().
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    # strip_accents=None (unset) defaults to stripping accents when
                    # lower-casing, matching the original BERT behavior.
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # NFD decomposition separates base characters from combining accent marks.
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            # "Mn" = Mark, Nonspacing — the combining accents produced by NFD.
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Each punctuation character becomes its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the Unicode replacement character, and control characters.
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
+
+
class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        # vocab: mapping of known (sub)word pieces; unk_token: fallback for
        # words that cannot be decomposed; max_input_chars_per_word: guard
        # against pathologically long inputs.
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def _match_pieces(self, word):
        """
        Greedy longest-match-first decomposition of a single *word* into vocab
        pieces. Returns the list of pieces, or `None` if some span of the word
        matches nothing in the vocabulary.
        """
        pieces = []
        pos = 0
        while pos < len(word):
            stop = len(word)
            match = None
            # Shrink the candidate span from the right until it is in the vocab.
            while pos < stop:
                candidate = word[pos:stop]
                if pos > 0:
                    # Non-initial pieces carry the continuation prefix.
                    candidate = "##" + candidate
                if candidate in self.vocab:
                    match = candidate
                    break
                stop -= 1
            if match is None:
                return None
            pieces.append(match)
            pos = stop
        return pieces

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for token in whitespace_tokenize(text):
            # Overly long words are mapped straight to the unknown token.
            if len(token) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue
            pieces = self._match_pieces(token)
            if pieces is None:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(pieces)
        return output_tokens
+
+
# Public API of this module.
__all__ = ["RealmTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c173227befd31e7f6e1fed26877dceb13ad4041
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright 2022 The REALM authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Tokenization classes for REALM."""
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers
+
+from ....tokenization_utils_base import BatchEncoding
+from ....tokenization_utils_fast import PreTrainedTokenizerFast
+from ....utils import PaddingStrategy, logging
+from .tokenization_realm import RealmTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
class RealmTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" REALM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    [`RealmTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation
    splitting and wordpiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        clean_text (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing all
            whitespaces by the classic one.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    # Used by the conversion machinery to fall back to the slow tokenizer.
    slow_tokenizer_class = RealmTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # If the serialized backend normalizer disagrees with the options passed
        # here (e.g. a tokenizer.json saved with different casing settings),
        # rebuild it so the requested options take effect.
        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        self.do_lower_case = do_lower_case

    def batch_encode_candidates(self, text, **kwargs):
        r"""
        Encode a batch of text or text pair. This method is similar to regular __call__ method but has the following
        differences:

        1. Handle additional num_candidate axis. (batch_size, num_candidates, text)
        2. Always pad the sequences to *max_length*.
        3. Must specify *max_length* in order to stack packs of candidates into a batch.

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            text (`List[List[str]]`):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            text_pair (`List[List[str]]`, *optional*):
                The batch of sequences to be encoded. Each sequence must be in this format: (batch_size,
                num_candidates, text).
            **kwargs:
                Keyword arguments of the __call__ method.

        Returns:
            [`BatchEncoding`]: Encoded text or text pair.

        Example:

        ```python
        >>> from transformers import RealmTokenizerFast

        >>> # batch_size = 2, num_candidates = 2
        >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]

        >>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder")
        >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
        ```"""

        # Always using a fixed sequence length to encode in order to stack candidates into a batch.
        kwargs["padding"] = PaddingStrategy.MAX_LENGTH

        batch_text = text
        batch_text_pair = kwargs.pop("text_pair", None)
        return_tensors = kwargs.pop("return_tensors", None)

        output_data = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }

        for idx, candidate_text in enumerate(batch_text):
            if batch_text_pair is not None:
                candidate_text_pair = batch_text_pair[idx]
            else:
                candidate_text_pair = None

            # Tensor conversion is deferred until all candidates are collected.
            encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs)

            encoded_input_ids = encoded_candidates.get("input_ids")
            encoded_attention_mask = encoded_candidates.get("attention_mask")
            encoded_token_type_ids = encoded_candidates.get("token_type_ids")

            if encoded_input_ids is not None:
                output_data["input_ids"].append(encoded_input_ids)
            if encoded_attention_mask is not None:
                output_data["attention_mask"].append(encoded_attention_mask)
            if encoded_token_type_ids is not None:
                output_data["token_type_ids"].append(encoded_token_type_ids)

        # Drop fields the underlying tokenizer did not produce.
        output_data = {key: item for key, item in output_data.items() if len(item) != 0}

        return BatchEncoding(output_data, tensor_type=return_tensors)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A REALM sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save the backend tokenizer's vocabulary files and return their paths."""
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
+
+
# Public API of this module.
__all__ = ["RealmTokenizerFast"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a875576607430c041abab01eccaf468a6cc9272e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_retribert import *
+ from .modeling_retribert import *
+ from .tokenization_retribert import *
+ from .tokenization_retribert_fast import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d755a16961450cae783d834385e5e6873dc24e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""RetriBERT model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class RetriBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`RetriBertModel`]. It is used to instantiate a
    RetriBertModel model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the RetriBERT
    [yjernite/retribert-base-uncased](https://huggingface.co/yjernite/retribert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the RetriBERT model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`RetriBertModel`]
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the *token_type_ids* passed into [`BertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        share_encoders (`bool`, *optional*, defaults to `True`):
            Whether or not to use the same Bert-type encoder for the queries and document
        projection_dim (`int`, *optional*, defaults to 128):
            Final dimension of the query and document representation after projection
    """

    model_type = "retribert"

    # NOTE(review): the in-code default for num_hidden_layers is 8, while the
    # docstring above says 12 — kept as-is to preserve checkpoint compatibility.
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=8,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        share_encoders=True,
        projection_dim=128,
        pad_token_id=0,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.share_encoders = share_encoders
        self.projection_dim = projection_dim
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcae1c0239e6f9b214e6227bf279ad9fc7b8381a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py
@@ -0,0 +1,217 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+RetriBERT model
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from torch import nn
+
+from ....modeling_utils import PreTrainedModel
+from ....utils import add_start_docstrings, logging
+from ...bert.modeling_bert import BertModel
+from .configuration_retribert import RetriBertConfig
+
+
+logger = logging.get_logger(__name__)
+
+
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class RetriBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RetriBertConfig
    # No TensorFlow checkpoint conversion is provided for RetriBERT.
    load_tf_weights = None
    base_model_prefix = "retribert"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule (applied recursively by `PreTrainedModel`)."""
        if isinstance(module, nn.Linear):
            # Normal(0, initializer_range) for weights; biases start at zero.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding token's embedding row at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm starts as the identity transform: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
# Shared docstring fragment prepended to model docstrings via `add_start_docstrings`.
RETRIBERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`RetriBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    """Bert Based model to embed queries or document for document retrieval.""",
    RETRIBERT_START_DOCSTRING,
)
class RetriBertModel(RetriBertPreTrainedModel):
    def __init__(self, config: RetriBertConfig) -> None:
        super().__init__(config)
        self.projection_dim = config.projection_dim

        # The query encoder is always instantiated. When `config.share_encoders` is
        # True, `bert_doc` stays None and the query encoder is reused for documents
        # (see `embed_answers`).
        self.bert_query = BertModel(config)
        self.bert_doc = None if config.share_encoders else BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Separate bias-free projections map pooled BERT outputs into the shared
        # retrieval space of dimension `projection_dim`.
        self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
        self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        self.ce_loss = nn.CrossEntropyLoss(reduction="mean")

        # Initialize weights and apply final processing
        self.post_init()

    def embed_sentences_checkpointed(
        self,
        input_ids,
        attention_mask,
        sent_encoder,
        checkpoint_batch_size=-1,
    ):
        """Return pooled BERT representations, optionally recomputing activations in
        mini-batches of `checkpoint_batch_size` via gradient checkpointing to save memory."""
        # reproduces BERT forward pass with checkpointing
        if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size:
            # Batch is small enough: plain forward pass; index 1 is the pooled output.
            return sent_encoder(input_ids, attention_mask=attention_mask)[1]
        else:
            # prepare implicit variables
            device = input_ids.device
            input_shape = input_ids.size()
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            head_mask = [None] * sent_encoder.config.num_hidden_layers
            extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask(
                attention_mask, input_shape
            )

            # define function for checkpointing
            def partial_encode(*inputs):
                # Runs the transformer stack and pooler on one slice of embeddings.
                encoder_outputs = sent_encoder.encoder(
                    inputs[0],
                    attention_mask=inputs[1],
                    head_mask=head_mask,
                )
                sequence_output = encoder_outputs[0]
                pooled_output = sent_encoder.pooler(sequence_output)
                return pooled_output

            # run embedding layer on everything at once
            embedding_output = sent_encoder.embeddings(
                input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None
            )
            # run encoding and pooling on one mini-batch at a time
            pooled_output_list = []
            for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)):
                b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size]
                # Activations for this slice are recomputed during backward.
                pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask)
                pooled_output_list.append(pooled_output)
            return torch.cat(pooled_output_list, dim=0)

    def embed_questions(
        self,
        input_ids,
        attention_mask=None,
        checkpoint_batch_size=-1,
    ):
        """Encode queries with the query encoder and project them to the retrieval space."""
        q_reps = self.embed_sentences_checkpointed(
            input_ids,
            attention_mask,
            self.bert_query,
            checkpoint_batch_size,
        )
        return self.project_query(q_reps)

    def embed_answers(
        self,
        input_ids,
        attention_mask=None,
        checkpoint_batch_size=-1,
    ):
        """Encode documents (falling back to the query encoder when encoders are shared)
        and project them to the retrieval space."""
        a_reps = self.embed_sentences_checkpointed(
            input_ids,
            attention_mask,
            self.bert_query if self.bert_doc is None else self.bert_doc,
            checkpoint_batch_size,
        )
        return self.project_doc(a_reps)

    def forward(
        self,
        input_ids_query: torch.LongTensor,
        attention_mask_query: Optional[torch.FloatTensor],
        input_ids_doc: torch.LongTensor,
        attention_mask_doc: Optional[torch.FloatTensor],
        checkpoint_batch_size: int = -1,
    ) -> torch.FloatTensor:
        r"""
        Args:
            input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary for the queries in a batch.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask_query (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            input_ids_doc (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary for the documents in a batch.
            attention_mask_doc (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on documents padding token indices.
            checkpoint_batch_size (`int`, *optional*, defaults to `-1`):
                If greater than 0, uses gradient checkpointing to only compute sequence representation on
                `checkpoint_batch_size` examples at a time on the GPU. All query representations are still compared to
                all document representations in the batch.

        Return:
            `torch.FloatTensor`: The bidirectional cross-entropy loss obtained while trying to match each query to its
            corresponding document and each document to its corresponding query in the batch
        """
        device = input_ids_query.device
        q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size)
        a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size)
        # Similarity matrix between every query and every document in the batch.
        compare_scores = torch.mm(q_reps, a_reps.t())
        # In-batch contrastive loss: the i-th query should match the i-th document
        # and vice versa, so the targets are simply 0..batch-1 in both directions.
        # NOTE(review): this assumes the query and document batches have equal size,
        # otherwise the arange targets would not line up — confirm at call sites.
        loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device))
        loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device))
        loss = (loss_qa + loss_aq) / 2
        return loss
+
+
+__all__ = ["RetriBertModel", "RetriBertPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a1874aa0e3c8c39e3d96c20a31d3b0595235f1
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py
@@ -0,0 +1,504 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for RetriBERT."""
+
+import collections
+import os
+import unicodedata
+from typing import List, Optional, Tuple
+
+from ....tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
# Maps the standard file-name key to the on-disk vocabulary file name.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
def load_vocab(vocab_file):
    """Read a vocabulary file and map each token to its (0-based) line index."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        lines = reader.readlines()
    for index, line in enumerate(lines):
        vocab[line.rstrip("\n")] = index
    return vocab
+
+
def whitespace_tokenize(text):
    """Trim surrounding whitespace and split *text* on runs of whitespace."""
    stripped = text.strip()
    if not stripped:
        return []
    return stripped.split()
+
+
class RetriBertTokenizer(PreTrainedTokenizer):
    r"""
    Constructs a RetriBERT tokenizer.

    [`RetriBertTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting
    and wordpiece.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        # Reverse mapping (id -> token) used for decoding.
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )

        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        # NOTE(review): the base-class init is deliberately called after the sub-tokenizers
        # are set up — it appears to rely on the attributes defined above.
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

    @property
    def do_lower_case(self):
        # Mirrors the basic tokenizer's setting (the single source of truth).
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        # Base vocabulary plus any tokens added after initialization.
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text, split_special_tokens=False):
        """Split *text* into (sub)tokens: basic tokenization first (if enabled), then WordPiece."""
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(
                text, never_split=self.all_special_tokens if not split_special_tokens else None
            ):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        # Unknown tokens fall back to the [UNK] id.
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        # Joining on spaces and dropping " ##" undoes the WordPiece continuation markers.
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Marks the positions [CLS] ... [SEP] (... [SEP]) would occupy.
        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary, one token per line in id order, and return the file path."""
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            # `save_directory` is treated as a file path when it is not a directory.
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # Non-consecutive ids would make line numbers disagree with token ids.
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
+
+
class BasicTokenizer:
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
            the full context of the words, such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # prevents treating the same character with different unicode codepoints as different characters
        unicode_normalized_text = unicodedata.normalize("NFC", text)
        orig_tokens = whitespace_tokenize(unicode_normalized_text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                # When lowercasing, accents are stripped unless explicitly disabled
                # (strip_accents=None defaults to following do_lower_case, as in original BERT).
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        # Re-split on whitespace to flatten the per-token punctuation splits.
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # NFD decomposition separates base characters from combining marks ("Mn"),
        # which are then dropped.
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Each punctuation char becomes its own token.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the Unicode replacement character, and control characters;
            # normalize all other whitespace to a plain space.
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
+
+
class WordpieceTokenizer:
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        pieces = []
        for word in whitespace_tokenize(text):
            chars = list(word)
            # Overlong words are mapped straight to the unknown token.
            if len(chars) > self.max_input_chars_per_word:
                pieces.append(self.unk_token)
                continue

            word_pieces = []
            start = 0
            while start < len(chars):
                # Greedy longest-match-first: shrink the candidate from the right
                # until it is found in the vocabulary.
                match = None
                for end in range(len(chars), start, -1):
                    candidate = "".join(chars[start:end])
                    if start > 0:
                        # Non-initial pieces carry the continuation prefix.
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        match = candidate
                        start = end
                        break
                if match is None:
                    # No prefix of the remainder is in the vocabulary: the whole
                    # word becomes the unknown token.
                    word_pieces = None
                    break
                word_pieces.append(match)

            if word_pieces is None:
                pieces.append(self.unk_token)
            else:
                pieces.extend(word_pieces)
        return pieces
+
+
+__all__ = ["RetriBertTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc51d0e2a1de9280e6b8635804eaf46f7859e5b9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py
@@ -0,0 +1,179 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for RetriBERT."""
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers
+
+from ....tokenization_utils_fast import PreTrainedTokenizerFast
+from ....utils import logging
+from .tokenization_retribert import RetriBertTokenizer
+
+
+logger = logging.get_logger(__name__)
+
# Standard file-name keys for the slow vocabulary and the serialized fast tokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
class RetriBertTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" RetriBERT tokenizer (backed by HuggingFace's *tokenizers* library).

    [`RetriBertTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation
    splitting and wordpiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        clean_text (`bool`, *optional*, defaults to `True`):
            Whether or not to clean the text before tokenization by removing any control characters and replacing all
            whitespaces by the classic one.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = RetriBertTokenizer
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # If the backend tokenizer's serialized normalizer disagrees with the
        # arguments passed here, rebuild the normalizer so the arguments win.
        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        if (
            normalizer_state.get("lowercase", do_lower_case) != do_lower_case
            or normalizer_state.get("strip_accents", strip_accents) != strip_accents
            or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars
        ):
            normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
            normalizer_state["lowercase"] = do_lower_case
            normalizer_state["strip_accents"] = strip_accents
            normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars
            self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)

        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]

        if token_ids_1 is not None:
            output += token_ids_1 + [self.sep_token_id]

        return output

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Delegate saving to the backend tokenizer model and return the written file paths."""
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)
+
+
+__all__ = ["RetriBertTokenizerFast"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c549b6e294e1ca97a718ef601472d4d400e12a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # For static type checkers / IDEs, expose the real symbols eagerly.
+    from .configuration_speech_to_text_2 import *
+    from .modeling_speech_to_text_2 import *
+    from .processing_speech_to_text_2 import *
+    from .tokenization_speech_to_text_2 import *
+else:
+    import sys
+
+    # At runtime, replace this module with a lazy proxy so heavy submodules
+    # (e.g. the PyTorch modeling file) are only imported on first attribute access.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2afd79feb28dd99c6ca94d2483b8f88b8cdc0a0e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py
@@ -0,0 +1,134 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Speech2Text2 model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Speech2Text2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Speech2Text2ForCausalLM`]. It is used to
+    instantiate an Speech2Text2 model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Speech2Text2
+    [facebook/s2t-wav2vec2-large-en-de](https://huggingface.co/facebook/s2t-wav2vec2-large-en-de) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 10000):
+            Vocabulary size of the Speech2Text2 model. Defines the number of different tokens that can be represented
+            by the `input_ids` passed when calling [`Speech2Text2ForCausalLM`]
+        d_model (`int`, *optional*, defaults to 256):
+            Dimensionality of the layers and the pooler layer.
+        decoder_layers (`int`, *optional*, defaults to 6):
+            Number of decoder layers.
+        decoder_attention_heads (`int`, *optional*, defaults to 4):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the pooler. If string, `"gelu"`, `"relu"`,
+            `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
+            for more details.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models).
+        decoder_start_token_id (`int`, *optional*, defaults to 2):
+            Token id used to start decoding.
+        scale_embedding (`bool`, *optional*, defaults to `True`):
+            Whether to scale the token embeddings by `sqrt(d_model)`.
+        pad_token_id (`int`, *optional*, defaults to 1):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 0):
+            Beginning-of-sequence token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End-of-sequence token id.
+        max_target_positions (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+
+    Example:
+
+    ```python
+    >>> from transformers import Speech2Text2Config, Speech2Text2ForCausalLM
+
+    >>> # Initializing a Speech2Text2 s2t_transformer_s style configuration
+    >>> configuration = Speech2Text2Config()
+
+    >>> # Initializing a model (with random weights) from the s2t_transformer_s style configuration
+    >>> model = Speech2Text2ForCausalLM(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "speech_to_text_2"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    # Map generic attribute names used by the library onto this model's names.
+    attribute_map = {"num_attention_heads": "decoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=10000,
+        decoder_layers=6,
+        decoder_ffn_dim=2048,
+        decoder_attention_heads=4,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        activation_function="relu",
+        d_model=256,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        decoder_start_token_id=2,
+        scale_embedding=True,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        max_target_positions=1024,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.decoder_layerdrop = decoder_layerdrop
+        self.use_cache = use_cache
+        # Decoder-only model: the generic layer count mirrors the decoder depth.
+        self.num_hidden_layers = decoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        self.max_target_positions = max_target_positions
+
+        # Special-token ids are handled by the PretrainedConfig base class.
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            decoder_start_token_id=decoder_start_token_id,
+            **kwargs,
+        )
+
+
+__all__ = ["Speech2Text2Config"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1dd18d97ff61015759b99587b846f4825c427c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py
@@ -0,0 +1,930 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Speech2Text2 model."""
+
+import copy
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ....activations import ACT2FN
+from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from ....modeling_utils import PreTrainedModel
+from ....utils import add_start_docstrings, logging, replace_return_docstrings
+from .configuration_speech_to_text_2 import Speech2Text2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Speech2Text2Config"
+_CHECKPOINT_FOR_DOC = "facebook/s2t-wav2vec2-large-en-de"
+
+
+class Speech2Text2SinusoidalPositionalEmbedding(nn.Module):
+    """This module produces sinusoidal positional embeddings of any length."""
+
+    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
+        super().__init__()
+        # Fairseq convention: position numbering starts at padding_idx + 1, so the
+        # table carries two extra reserved slots.
+        self.offset = 2
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        # NOTE(review): `self.weights` is only created by `make_weights`, which is never
+        # called here; `forward` reads `self.weights` and would fail on a freshly
+        # constructed module — confirm the initialization path (e.g. `_init_weights`).
+
+    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+        """(Re)build the sinusoidal table as a frozen `nn.Parameter` of shape `(num_embeddings, embedding_dim)`."""
+        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
+        if hasattr(self, "weights"):
+            # in forward put the weights on the correct dtype and device of the param
+            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
+
+        # Stored as a parameter so it follows `.to(...)` moves, but frozen: positions are not learned.
+        self.weights = nn.Parameter(emb_weights)
+        self.weights.requires_grad = False
+        self.weights.detach_()
+
+    @staticmethod
+    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
+        """
+        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
+        description in Section 3.5 of "Attention Is All You Need".
+        """
+        half_dim = embedding_dim // 2
+        # Geometric progression of frequencies: exp(-log(10000) * i / (half_dim - 1)).
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
+        # First half sin, second half cos, concatenated along the feature axis.
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+        if embedding_dim % 2 == 1:
+            # zero pad
+            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+        if padding_idx is not None:
+            # The padding position gets an all-zero embedding.
+            emb[padding_idx, :] = 0
+        return emb.to(torch.get_default_dtype())
+
+    @torch.no_grad()
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """Look up positional embeddings for `input_ids`; returns `(bsz, seq_len, embedding_dim)`."""
+        bsz, seq_len = input_ids.size()
+        # Create the position ids from the input token ids. Any padded tokens remain padded.
+        position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
+            input_ids.device
+        )
+
+        # expand embeddings if needed
+        max_pos = self.padding_idx + 1 + seq_len
+        if max_pos > self.weights.size(0):
+            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
+
+        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
+
+    def create_position_ids_from_input_ids(
+        self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
+    ):
+        """
+        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+
+        Args:
+            input_ids (`torch.Tensor`): Token ids of shape `(bsz, seq_len)`.
+            padding_idx (`int`): Id whose positions are masked out (stay at `padding_idx`).
+            past_key_values_length (`int`, *optional*): Offset added for cached decoding steps.
+
+        Returns:
+            `torch.Tensor`: Position ids of the same shape as `input_ids`.
+        """
+        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+        mask = input_ids.ne(padding_idx).int()
+        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+        return incremental_indices.long() + padding_idx
+
+
+class Speech2Text2Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[Speech2Text2Config] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        # Scaled dot-product attention: queries are pre-multiplied by 1/sqrt(head_dim).
+        self.scaling = self.head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        # Fold the head dimension into the batch so attention is a single bmm per head.
+        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            # Additive mask: masked positions carry large negative values before the softmax.
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class Speech2Text2DecoderLayer(nn.Module):
+    """A single Transformer decoder layer: self-attention, optional cross-attention, feed-forward (post-LN)."""
+
+    def __init__(self, config: Speech2Text2Config):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = Speech2Text2Attention(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+        # NOTE(review): cross-attention is only built when `config.is_decoder` is set;
+        # `forward` uses `self.encoder_attn` whenever `encoder_hidden_states` is passed,
+        # so callers must ensure the config flag matches — confirm against the caller.
+        if config.is_decoder:
+            self.encoder_attn = Speech2Text2Attention(
+                self.embed_dim,
+                config.decoder_attention_heads,
+                dropout=config.attention_dropout,
+                is_decoder=True,
+            )
+            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size *(decoder_attention_heads,)*.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class Speech2Text2PreTrainedModel(PreTrainedModel):
+    """Base class wiring config class, weight init, and gradient checkpointing for Speech2Text2 modules."""
+
+    config_class = Speech2Text2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize module weights with a truncated-normal-style scheme using `config.init_std`."""
+        std = self.config.init_std
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                # Keep the padding embedding at zero.
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding):
+            # Rebuild the (non-trainable) sinusoidal table instead of random init.
+            # NOTE(review): the sinusoidal module stores its table as `weights` (plural),
+            # while this branch reads/writes `module.weight` — confirm which attribute
+            # actually exists on a freshly constructed module.
+            weight = module.get_embedding(*module.weight.shape, module.padding_idx)
+            weight = nn.Parameter(weight, requires_grad=False)
+            weight.detach_()
+            module.weight = weight
+
+
+# Shared docstring fragment injected into model classes via `add_start_docstrings`.
+SPEECH_TO_TEXT_2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`Speech2Text2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Speech2Text2DecoderLayer`]
+
+ Args:
+ config: Speech2Text2Config
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+    def __init__(self, config: Speech2Text2Config):
+        super().__init__(config)
+        self.dropout = config.dropout
+        # Probability of skipping an entire decoder layer during training (LayerDrop).
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_target_positions
+        # Transformer convention: scale token embeddings by sqrt(d_model) when enabled.
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
+        # Non-learned sinusoidal positions, added to the scaled token embeddings.
+        self.embed_positions = Speech2Text2SinusoidalPositionalEmbedding(
+            self.max_target_positions,
+            config.d_model,
+            self.padding_idx,
+        )
+
+        self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)])
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        # Accessor used by the library for embedding resizing/tying.
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        # Replace the token embedding table (e.g. after vocabulary resizing).
+        self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ head_mask=None,
+ cross_attn_head_mask=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`Speech2Text2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+ of the decoder.
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+ selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
+ on hidden heads. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ # embed positions
+ positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+ next_decoder_cache = () if use_cache else None
+
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+ for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+ if attn_mask is not None:
+ if attn_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ head_mask[idx] if head_mask is not None else None,
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ cross_attn_layer_head_mask=(
+ cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+ ),
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if encoder_hidden_states is not None:
+ all_cross_attentions += (layer_outputs[2],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
@add_start_docstrings(
    "The Speech2Text2 Model with a language modeling head. Can be used for summarization.",
    SPEECH_TO_TEXT_2_START_DOCSTRING,
)
class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
    """
    Thin helper around [`Speech2Text2Decoder`] that preserves the ``model.decoder.*`` parameter
    naming of pretrained checkpoints when the causal LM is used inside the [`EncoderDecoderModel`]
    framework.
    """

    def __init__(self, config):
        super().__init__(config)
        # The wrapped decoder owns all actual parameters; this class adds none of its own.
        self.decoder = Speech2Text2Decoder(config)

    def forward(self, *args, **kwargs):
        # Pure pass-through: every call is delegated unchanged to the wrapped decoder.
        return self.decoder(*args, **kwargs)
+
+
@add_start_docstrings(
    "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of"
    " [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].",
    SPEECH_TO_TEXT_2_START_DOCSTRING,
)
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
    """Speech2Text2 decoder-only language model: a [`Speech2Text2Decoder`] topped by a linear LM head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        # Deep-copy so the decoder-only flags set below do not leak into the caller's config object.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        # The wrapper keeps the `model.decoder.*` parameter naming expected by pretrained checkpoints.
        self.model = Speech2Text2DecoderWrapper(config)

        # Bias-free projection; its weight may be tied to the input embeddings (see `_tied_weights_keys`).
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], CausalLMOutputWithCrossAttentions]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary, obtainable via [`Speech2Text2Tokenizer`].
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Padding mask: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Last hidden states of the encoder, consumed by the cross-attention layers.
            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Padding mask over the encoder input, used by the cross-attention layers.
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Per-layer, per-head mask for self-attention (1 keeps the head, 0 nullifies it).
            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Per-layer, per-head mask for cross-attention (1 keeps the head, 0 nullifies it).
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
                Pre-computed key/value states that speed up sequential decoding; when provided, only the
                last `input_ids` (shape `(batch_size, 1)`) need to be passed.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for the language-modeling loss, in `[0, ..., config.vocab_size]`; entries set to
                `-100` are ignored.
            use_cache (`bool`, *optional*):
                If `True`, `past_key_values` are returned to speed up decoding.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        Example:

        ```python
        >>> from transformers import (
        ...     SpeechEncoderDecoderModel,
        ...     Speech2Text2ForCausalLM,
        ...     Wav2Vec2Model,
        ...     Speech2Text2Config,
        ...     Wav2Vec2Config,
        ...     Wav2Vec2FeatureExtractor,
        ...     Speech2Text2Tokenizer,
        ... )
        >>> from datasets import load_dataset

        >>> feature_extractor = Wav2Vec2FeatureExtractor()
        >>> tokenizer = Speech2Text2Tokenizer.from_pretrained("facebook/s2t-wav2vec2-large-en-de")

        >>> encoder = Wav2Vec2Model(Wav2Vec2Config())
        >>> decoder = Speech2Text2ForCausalLM(Speech2Text2Config())
        >>> # init random speech2text model

        >>> model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
        >>> model.config.pad_token_id = tokenizer.pad_token_id
        >>> model.config.decoder_start_token_id = tokenizer.bos_token_id
        >>> # pre-process inputs and labels

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(
        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
        ... )
        >>> input_values = inputs.input_values
        >>> decoder_input_ids = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
        >>> # compute loss

        >>> loss = model(inputs=input_values, labels=decoder_input_ids).loss
        >>> # backprop loss

        >>> loss.backward()  # doctest: +IGNORE_RESULT
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        decoder_outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(decoder_outputs[0])

        loss = None
        if labels is not None:
            # Flatten to (batch * seq, vocab) vs (batch * seq,); `-100` labels are ignored by default.
            loss = nn.functional.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + decoder_outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # When used as the decoder of an encoder-decoder model, the attention mask is created on the fly.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            cached_length = past_key_values[0][0].shape[2]
            # Some generation methods already pass only the last input ID; otherwise drop the
            # prefix that is already covered by the cache (default: keep only the final ID).
            keep_from = cached_length if input_ids.shape[1] > cached_length else input_ids.shape[1] - 1
            input_ids = input_ids[:, keep_from:]

        # On the first step the cache is empty and `input_ids` passes through untouched.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Re-index every cached key/value tensor along the batch dimension to follow the beams.
        return tuple(
            tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            for layer_past in past_key_values
        )


__all__ = ["Speech2Text2ForCausalLM", "Speech2Text2PreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b8edaa46d10b8ca04f68d530733fcb7cb390ee8
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py
@@ -0,0 +1,119 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Speech processor class for Speech2Text2
+"""
+
+import warnings
+from contextlib import contextmanager
+
+from ....processing_utils import ProcessorMixin
+
+
class Speech2Text2Processor(ProcessorMixin):
    r"""
    Constructs a Speech2Text2 processor which wraps a Speech2Text2 feature extractor and a Speech2Text2 tokenizer into
    a single processor.

    [`Speech2Text2Processor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`Speech2Text2Tokenizer`].
    See the [`~Speech2Text2Processor.__call__`] and [`~Speech2Text2Processor.decode`] for more information.

    Args:
        feature_extractor (`AutoFeatureExtractor`):
            An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
        tokenizer (`Speech2Text2Tokenizer`):
            An instance of [`Speech2Text2Tokenizer`]. The tokenizer is a required input.
    """

    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "Speech2Text2Tokenizer"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor, tokenizer)
        # `current_processor` is what `__call__` forwards to while inside the deprecated
        # `as_target_processor` context manager.
        self.current_processor = self.feature_extractor
        self._in_target_context_manager = False

    def __call__(self, *args, **kwargs):
        """
        When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's
        [`~AutoFeatureExtractor.__call__`] and returns its output. If used in the context
        [`~Speech2Text2Processor.as_target_processor`] this method forwards all its arguments to
        Speech2Text2Tokenizer's [`~Speech2Text2Tokenizer.__call__`]. Please refer to the docstring of the above two
        methods for more information.
        """
        # For backward compatibility with the deprecated `as_target_processor` context manager.
        if self._in_target_context_manager:
            return self.current_processor(*args, **kwargs)

        if "raw_speech" in kwargs:
            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
            audio = kwargs.pop("raw_speech")
        else:
            audio = kwargs.pop("audio", None)
        sampling_rate = kwargs.pop("sampling_rate", None)
        text = kwargs.pop("text", None)
        # A single positional argument is interpreted as the audio input.
        if len(args) > 0:
            audio = args[0]
            args = args[1:]

        if audio is None and text is None:
            raise ValueError("You need to specify either an `audio` or `text` input to process.")

        if audio is not None:
            inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs)
        if text is not None:
            encodings = self.tokenizer(text, **kwargs)

        # Return audio features, text encodings, or both (text as `labels`).
        if text is None:
            return inputs
        elif audio is None:
            return encodings
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @contextmanager
    def as_target_processor(self):
        """
        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
        Speech2Text2.
        """
        # Fixed: the parenthetical in this message previously had no closing parenthesis.
        warnings.warn(
            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
            "your audio inputs, or in a separate call)."
        )
        self._in_target_context_manager = True
        self.current_processor = self.tokenizer
        try:
            yield
        finally:
            # Restore normal mode even if the caller's block raised, so a failed target-encoding
            # pass cannot leave the processor stuck forwarding to the tokenizer.
            self.current_processor = self.feature_extractor
            self._in_target_context_manager = False


__all__ = ["Speech2Text2Processor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5aa7ef8067c57d2fdae3b10c9ad3b01c3e23299
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py
@@ -0,0 +1,252 @@
+# coding=utf-8
+# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Speech2Text2."""
+
+import json
+import os
+from typing import Dict, List, Optional, Tuple
+
+from ....tokenization_utils import PreTrainedTokenizer
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
# File names the tokenizer reads from / saves to a checkpoint directory.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_config_file": "tokenizer_config.json",
    "merges_file": "merges.txt",
}


# fairseq-style BPE markers: "</w>" terminates a word in the merges file, while "@@ " marks a
# non-final sub-word unit in the vocabulary / tokenized output. The "</w>" marker had been
# stripped to an empty string (HTML-tag-like sanitizing), which silently disabled the
# end-of-word handling in `Speech2Text2Tokenizer.bpe`; restored here.
BPE_TOKEN_MERGES = "</w>"
BPE_TOKEN_VOCAB = "@@ "
+
+
def get_pairs(word):
    """
    Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
    strings)
    """
    # Indexing word[0] first preserves the original's IndexError on an empty word.
    word[0]
    # Adjacent bigrams are exactly the pairwise zip of the word with itself shifted by one.
    return set(zip(word, word[1:]))
+
+
+# Speech2Text2 has no max input length
+
+
class Speech2Text2Tokenizer(PreTrainedTokenizer):
    """
    Constructs a Speech2Text2Tokenizer.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
    the superclass for more information regarding such methods.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sentence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sentence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        do_lower_case (`bool`, *optional*, defaults to `False`):
            Whether or not to lowercase the input before tokenizing.
        merges_file (`str`, *optional*):
            File containing the BPE merges. If omitted, the tokenizer can only decode, not encode.

        **kwargs
            Additional keyword arguments passed along to [`PreTrainedTokenizer`]
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        # The special-token defaults below had been stripped to empty strings (HTML-tag-like
        # sanitizing); restored to the fairseq-style markers used by the released checkpoints.
        bos_token="<s>",
        pad_token="<pad>",
        eos_token="</s>",
        unk_token="<unk>",
        do_lower_case=False,
        merges_file=None,
        **kwargs,
    ):
        self.do_lower_case = do_lower_case

        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        # Reverse mapping for id -> token decoding.
        self.decoder = {v: k for k, v in self.encoder.items()}

        if merges_file is None:
            logger.info(f"No merges files provided. {self.__class__.__name__} can only be used for decoding.")

            # Decode-only mode: `_tokenize` raises if encoding is attempted.
            self.bpe_ranks = None
            self.cache = None
        else:
            with open(merges_file, encoding="utf-8") as merges_handle:
                merges = merges_handle.read().split("\n")[:-1]

            # Each line is "first second [frequency]" -- only the pair matters; rank = line order.
            merges = [tuple(merge.split()[:2]) for merge in merges]
            self.bpe_ranks = dict(zip(merges, range(len(merges))))
            self.cache = {}
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            do_lower_case=do_lower_case,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        return len(self.decoder)

    def get_vocab(self) -> Dict:
        return dict(self.encoder, **self.added_tokens_encoder)

    def bpe(self, token):
        """Apply byte-pair merges to a single whitespace-split token; returns a space-joined string of units."""
        # Check the cache before doing any work (previously the word tuple was built first).
        if token in self.cache:
            return self.cache[token]
        # Mark the last symbol with the end-of-word marker before computing merges.
        word = tuple(token[:-1]) + (token[-1] + BPE_TOKEN_MERGES,)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Greedily merge the lowest-ranked (earliest-learned) pair present in the word.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        # Special-case a bare newline token so the marker stays attached to it.
        if word == "\n " + BPE_TOKEN_MERGES:
            word = "\n" + BPE_TOKEN_MERGES

        if word.endswith(BPE_TOKEN_MERGES):
            word = word.replace(BPE_TOKEN_MERGES, "")

        # Non-final units are flagged with the "@@ " continuation marker.
        word = word.replace(" ", BPE_TOKEN_VOCAB)
        self.cache[token] = word
        return word

    def _tokenize(self, text):
        """Tokenize a string into BPE sub-word units."""

        if self.bpe_ranks is None:
            raise ValueError(
                "This tokenizer was instantiated without a `merges.txt` file, so"
                " that it can only be used for decoding, not for encoding. "
                "Make sure to provide `merges.txt` file at instantiation to enable "
                "encoding."
            )

        if self.do_lower_case:
            text = text.lower()

        text = text.split()

        split_tokens = []
        for token in text:
            if token:
                # `str.split` already returns a list; the previous extra `list()` wrap was redundant.
                split_tokens.extend(self.bpe(token).split(" "))

        return split_tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) in an index (integer) using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the vocab."""
        result = self.decoder.get(index, self.unk_token)
        return result

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Converts a list of output tokens into a single string.
        """
        # combine tokens
        string = " ".join(tokens)

        # make sure @@ tokens are concatenated
        string = "".join(string.split(BPE_TOKEN_VOCAB))

        return string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write `vocab.json` (and `merges.txt` when available) to `save_directory`; returns the written paths."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            # NOTE(review): returns None here rather than raising -- kept for backward compatibility.
            return
        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
        merges_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        index = 0
        if self.bpe_ranks is None:
            # Decode-only tokenizer: there are no merges to save.
            return (vocab_file,)

        with open(merges_file, "w", encoding="utf-8") as writer:
            # Merges must be written in rank order; warn if the ranks are not consecutive.
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {merges_file}: BPE merge indices are not consecutive."
                        " Please check that the tokenizer is not corrupted!"
                    )
                    index = token_index
                writer.write(" ".join(bpe_tokens) + "\n")
                index += 1

        return (vocab_file, merges_file)


__all__ = ["Speech2Text2Tokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b535eb1df2c40a7680bbf8d57fc70b78e23437f4
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # At type-check time, expose the real symbols so static analysis sees them directly.
    from .tokenization_tapex import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy that only imports submodules
    # when one of their attributes is first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py b/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d554872c48df0c4f8ae9f4f63e98affca511688
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py
@@ -0,0 +1,1470 @@
+# coding=utf-8
+# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for TAPEX."""
+
+import json
+import os
+import random
+from functools import lru_cache
+from typing import Dict, List, Optional, Tuple, Union
+
+import regex as re
+
+from ....file_utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available
+from ....tokenization_utils import AddedToken, PreTrainedTokenizer
+from ....tokenization_utils_base import ENCODE_KWARGS_DOCSTRING, BatchEncoding, TextInput, TruncationStrategy
+from ....utils import logging
+
+
+if is_pandas_available():
+ import pandas as pd
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
+
+
class TapexTruncationStrategy(ExplicitEnum):
    """
    Possible values for the `truncation` argument in [`~TapexTokenizer.__call__`]. Useful for tab-completion in an
    IDE.
    """

    # Drop whole table rows from the end of the table until the linearized input fits `max_length`.
    DROP_ROWS_TO_FIT = "drop_rows_to_fit"
+
+
+TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
+ Whether or not to encode the sequences with the special tokens relative to their model.
+ padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence if provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str`, [`TapexTruncationStrategy`] or [`~tokenization_utils_base.TruncationStrategy`],
+ *optional*, defaults to `False`):
+
+ Activates and controls truncation. Accepts the following values:
+
+ - `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will truncate
+ row by row, removing rows from the table.
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+ to the maximum acceptable input length for the model if that argument is not provided. This will
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
+ sequences (or a batch of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+ maximum acceptable input length for the model if that argument is not provided. This will only
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
+ `None`, this will use the predefined model maximum length if a maximum length is required by one of the
+ truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
+ truncation/padding to a maximum length will be deactivated.
+ stride (`int`, *optional*, defaults to 0):
+ If set to a number along with `max_length`, the overflowing tokens returned when
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
+ argument defines the number of overlapping tokens.
+ pad_to_multiple_of (`int`, *optional*):
+ If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+ the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+ return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+
@lru_cache()
def bytes_to_unicode():
    """
    Return a dict mapping every utf-8 byte value (0-255) to a printable unicode character.

    The BPE code operates on unicode strings and misbehaves on whitespace/control characters, so printable bytes map
    to themselves while the remaining byte values are shifted into the 256+ codepoint range. The mapping is a
    bijection, which keeps the byte-level encoding reversible (see `byte_decoder`). Cached: the table is computed
    once per process.
    """
    # Bytes that are safe to represent as themselves: printable ASCII plus two Latin-1 ranges.
    byte_values = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    codepoints = byte_values[:]
    shift = 0
    for byte in range(2**8):
        if byte in byte_values:
            continue
        # Unsafe byte: assign it the next unused codepoint above 255.
        byte_values.append(byte)
        codepoints.append(2**8 + shift)
        shift += 1
    return {b: chr(cp) for b, cp in zip(byte_values, codepoints)}
+
+
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in `word`.

    `word` is a tuple of symbols, where each symbol is a variable-length string (single characters initially, merged
    sub-strings as BPE progresses).
    """
    pairs = set()
    previous = word[0]
    for symbol in word[1:]:
        pairs.add((previous, symbol))
        previous = symbol
    return pairs
+
+
class IndexedRowTableLinearize:
    """
    Flatten a table into a single string using the indexed-row format:

        col : col1 | col2 | col3 row 1 : val1 | val2 | val3 row 2 : ...
    """

    # Message shown when the input table dict is malformed. Fix: the original referenced
    # `self.PROMPT_MESSAGE` without ever defining it, so a malformed table raised
    # `AttributeError` instead of the intended assertion message.
    PROMPT_MESSAGE = (
        "Please check the format of your table content: it should be a dict with a 'header' list and a 'rows' list."
    )

    def process_table(self, table_content: Dict) -> str:
        """
        Convert a table (a dict with `header` and `rows` keys) into a flattened sequence with special symbols.
        """
        assert "header" in table_content and "rows" in table_content, self.PROMPT_MESSAGE
        # Process the header first, then each row in order.
        table_str = self.process_header(table_content["header"]) + " "
        for i, row_example in enumerate(table_content["rows"]):
            # NOTE: rows are 1-indexed in the linearized output.
            table_str += self.process_row(row_example, row_index=i + 1) + " "
        return table_str.strip()

    def process_header(self, headers: List) -> str:
        """
        Convert a list of column headers into a flattened sequence with special symbols.
        """
        return "col : " + " | ".join(headers)

    def process_row(self, row: List, row_index: int) -> str:
        """
        Convert one table row into a flattened sequence with special symbols.

        Generalized: any non-string cell (int, float, ...) is coerced with `str()`; the original
        only handled `int`, so e.g. float cells crashed the `join`.
        """
        row_cell_values = [cell_value if isinstance(cell_value, str) else str(cell_value) for cell_value in row]
        return "row " + str(row_index) + " : " + " | ".join(row_cell_values)
+
+
+class TapexTokenizer(PreTrainedTokenizer):
+ r"""
+ Construct a TAPEX tokenizer. Based on byte-level Byte-Pair-Encoding (BPE).
+
+ This tokenizer can be used to flatten one or more table(s) and concatenate them with one or more related sentences
+ to be used by TAPEX models. The format that the TAPEX tokenizer creates is the following:
+
+ sentence col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ...
+
+ The tokenizer supports a single table + single query, a single table and multiple queries (in which case the table
+ will be duplicated for every query), a single query and multiple tables (in which case the query will be duplicated
+ for every table), and multiple tables and queries. In other words, you can provide a batch of tables + questions to
+ the tokenizer for instance to prepare them for the model.
+
+ Tokenization itself is based on the BPE algorithm. It is identical to the one used by BART, RoBERTa and GPT-2.
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ merges_file (`str`):
+ Path to the merges file.
+ do_lower_case (`bool`, *optional*, defaults to `True`):
+ Whether or not to lowercase the input when tokenizing.
+ errors (`str`, *optional*, defaults to `"replace"`):
+ Paradigm to follow when decoding bytes to UTF-8. See
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
+ bos_token (`str`, *optional*, defaults to `""`):
+ The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+
+
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+
+
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+
+
+ sep_token (`str`, *optional*, defaults to `""`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `""`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `""`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `""`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ add_prefix_space (`bool`, *optional*, defaults to `False`):
+ Whether or not to add an initial space to the input. This allows to treat the leading word just as any
+ other word. (BART tokenizer detect beginning of words by the preceding space).
+ max_cell_length (`int`, *optional*, defaults to 15):
+ Maximum number of characters per cell when linearizing a table. If this number is exceeded, truncation
+ takes place.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
    def __init__(
        self,
        vocab_file,
        merges_file,
        do_lower_case=True,
        errors="replace",
        # NOTE(review): all special-token defaults below are empty strings. BART-style BPE
        # tokenizers normally default to "<s>", "</s>", "<unk>", "<pad>", "<mask>" — these
        # look like they were lost in an extraction step; confirm against upstream.
        bos_token="",
        eos_token="",
        sep_token="",
        cls_token="",
        unk_token="",
        pad_token="",
        mask_token="",
        add_prefix_space=False,
        max_cell_length=15,
        **kwargs,
    ):
        """Load the BPE vocab/merges files and set up byte-level encoding tables."""
        # Wrap plain-string special tokens into AddedToken so stripping behavior is explicit.
        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
        sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
        cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token

        # Mask token behave like a normal word, i.e. include the space before it
        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

        # token -> id mapping, and its inverse for decoding.
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        # Byte-level tables: byte value <-> printable unicode char (see bytes_to_unicode).
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # merges.txt: first line is a version header (skipped), last split entry is empty (skipped).
        with open(merges_file, encoding="utf-8") as merges_handle:
            bpe_merges = merges_handle.read().split("\n")[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
        # Lower rank = higher merge priority in bpe().
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}  # memoizes bpe() results per pre-token
        self.add_prefix_space = add_prefix_space
        self.do_lower_case = do_lower_case

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        # additional properties

        # Superclass registers the special tokens; must run after our attributes exist.
        super().__init__(
            vocab_file=vocab_file,
            merges_file=merges_file,
            do_lower_case=do_lower_case,
            errors=errors,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            add_prefix_space=add_prefix_space,
            max_cell_length=max_cell_length,
            **kwargs,
        )

        # Maximum characters kept per table cell during linearization (truncation beyond this).
        self.max_cell_length = max_cell_length
        self.table_linearize = IndexedRowTableLinearize()
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. A TAPEX sequence has the following format:
+ - single sequence: ` X `
+ - pair of sequences: ` A B `
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Args:
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Args:
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. TAPEX does not:
+ make use of token type ids, therefore a list of zeros is returned.
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
+ if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()):
+ text = " " + text
+ return (text, kwargs)
+
    @property
    def vocab_size(self):
        """Size of the base vocabulary (excluding tokens added after loading)."""
        return len(self.encoder)
+
+ def get_vocab(self):
+ return dict(self.encoder, **self.added_tokens_encoder)
+
    def bpe(self, token):
        """
        Apply byte-pair encoding to one pre-tokenized chunk, returning its sub-tokens joined by spaces.

        Results are memoized in `self.cache` keyed by the raw chunk.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            # Single-symbol chunk: nothing to merge.
            return token

        while True:
            # Greedily merge the best-ranked (lowest index in merges.txt) adjacent pair.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                # No remaining pair is a known merge — done.
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    # Find the next occurrence of `first` at or after position i.
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                else:
                    new_word.extend(word[i:j])
                    i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    # Merge the bigram into a single symbol.
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        # Sub-tokens are space-joined; _tokenize splits on " " again.
        word = " ".join(word)
        self.cache[token] = word
        return word
+
+ def _tokenize(self, text):
+ """Tokenize a string."""
+ bpe_tokens = []
+ for token in re.findall(self.pat, text):
+ token = "".join(
+ self.byte_encoder[b] for b in token.encode("utf-8")
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
+ return bpe_tokens
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
+
    def _convert_id_to_token(self, index):
        """Convert an index (integer) to a token (str) using the vocab; returns None for unknown ids."""
        return self.decoder.get(index)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ text = "".join(tokens)
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+ return text
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ vocab_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+ )
+ merge_file = os.path.join(
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
+ )
+
+ with open(vocab_file, "w", encoding="utf-8") as f:
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+ index = 0
+ with open(merge_file, "w", encoding="utf-8") as writer:
+ writer.write("#version: 0.2\n")
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warning(
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
+ " Please check that the tokenizer is not corrupted!"
+ )
+ index = token_index
+ writer.write(" ".join(bpe_tokens) + "\n")
+ index += 1
+
+ return vocab_file, merge_file
+
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def __call__(
        self,
        table: Union["pd.DataFrame", List["pd.DataFrame"]] = None,
        query: Optional[Union[TextInput, List[TextInput]]] = None,
        answer: Union[str, List[str]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to tokenize and prepare for the model one or several table-sequence pair(s).

        Args:
            table (`pd.DataFrame`, `List[pd.DataFrame]`):
                Table(s) containing tabular data.
            query (`str` or `List[str]`, *optional*):
                Sentence or batch of sentences related to one or more table(s) to be encoded. Note that the number of
                sentences must match the number of tables.
            answer (`str` or `List[str]`, *optional*):
                Optionally, the corresponding answer to the questions as supervision.
        """

        if table is not None:
            # Encoder-side inputs: linearize the table(s) together with the optional query/answer.
            return self.source_call_func(
                table=table,
                query=query,
                answer=answer,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs,
            )
        elif answer is not None:
            # Decoder-side (target) inputs: only the answer text is tokenized.
            # NOTE: `target_call_func` is defined further down in this file.
            return self.target_call_func(
                answer=answer,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs,
            )
        else:
            raise ValueError("You need to provide either a `table` or an `answer`.")
+
    def source_call_func(
        self,
        table: Union["pd.DataFrame", List["pd.DataFrame"]],
        query: Optional[Union[TextInput, List[TextInput]]] = None,
        answer: Union[str, List[str]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Validate the table/query inputs and dispatch to `batch_encode_plus` (batched) or `encode_plus` (single
        example) for encoder-side inputs.
        """
        # Input type checking for clearer error
        valid_table = False
        valid_query = False

        # Check that table have a valid type
        if isinstance(table, pd.DataFrame):
            valid_table = True
        elif isinstance(table, (list, tuple)) and isinstance(table[0], pd.DataFrame):
            # Only the first element is checked; a mixed-type list slips through here.
            valid_table = True

        # Check that query have a valid type
        if query is None or isinstance(query, str):
            valid_query = True
        elif isinstance(query, (list, tuple)):
            if len(query) == 0 or isinstance(query[0], str):
                valid_query = True

        if not valid_table:
            raise ValueError(
                "table input must of type `pd.DataFrame` (single example), `List[pd.DataFrame]` (batch of examples). "
            )
        if not valid_query:
            raise ValueError("query input must of type `str` (single example), `List[str]` (batch of examples). ")
        # A list/tuple on either side means batched processing.
        is_batched = isinstance(table, (list, tuple)) or isinstance(query, (list, tuple))

        if is_batched:
            return self.batch_encode_plus(
                table=table,
                query=query,
                answer=answer,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs,
            )
        else:
            return self.encode_plus(
                table=table,
                query=query,
                answer=answer,
                add_special_tokens=add_special_tokens,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
                **kwargs,
            )
+
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def batch_encode_plus(
        self,
        table: Union["pd.DataFrame", List["pd.DataFrame"]],
        query: Optional[List[TextInput]] = None,
        answer: List[str] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str] = None,
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Tokenize and prepare a batch of tables and queries for the model.

        This method is deprecated, `__call__` should be used instead.
        """
        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            verbose=verbose,
            **kwargs,
        )

        return self._batch_encode_plus(
            table=table,
            query=query,
            answer=answer,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
+
    def _batch_encode_plus(
        self,
        table: Union["pd.DataFrame", List["pd.DataFrame"]],
        query: Optional[List[TextInput]] = None,
        answer: Optional[List[str]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Broadcast table/query to equal-length lists and delegate the actual encoding to
        `_batch_prepare_for_model`.
        """
        if return_offsets_mapping:
            # Offset mapping requires a fast (Rust-backed) tokenizer.
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
            )

        if isinstance(table, pd.DataFrame) and isinstance(query, (list, tuple)):
            # single table, many queries case
            # duplicate table for every query
            table = [table] * len(query)
        if isinstance(table, (list, tuple)) and isinstance(query, str):
            # many tables, single query case
            # duplicate query for every table
            query = [query] * len(table)

        batch_outputs = self._batch_prepare_for_model(
            table=table,
            query=query,
            answer=answer,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            return_tensors=return_tensors,
            verbose=verbose,
        )

        return BatchEncoding(batch_outputs)
+
    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
    def _batch_prepare_for_model(
        self,
        table: Union["pd.DataFrame", List["pd.DataFrame"]],
        query: Optional[Union[TextInput, List[TextInput]]] = None,
        answer: Optional[Union[str, List[str]]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ) -> BatchEncoding:
        """
        This method adds special tokens, truncates sequences if overflowing while taking into account the special
        tokens and manages a moving window (with user defined stride) for overflowing tokens.
        """
        batch_outputs = {}
        if answer is None:
            # Pad answers with None so the zip below stays aligned.
            answer = [None] * len(table)
        # NOTE(review): if `query` is None here, zip() raises TypeError — callers appear to
        # always broadcast it to a list first (see _batch_encode_plus); confirm.
        for _table, _query, _answer in zip(table, query, answer):
            # Linearize one table (+ query/answer) into a single flat string, truncating rows if needed.
            text = self.prepare_table_query(
                _table, _query, _answer, truncation_strategy=truncation_strategy, max_length=max_length
            )

            if self.do_lower_case:
                text = text.lower()

            tokens = self.tokenize(text)
            outputs = self.prepare_for_model(
                ids=self.convert_tokens_to_ids(tokens),
                add_special_tokens=add_special_tokens,
                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterwards
                truncation=truncation_strategy.value,
                max_length=max_length,
                stride=stride,
                pad_to_multiple_of=None,  # we pad in batch afterwards
                return_attention_mask=False,  # we pad in batch afterwards
                return_token_type_ids=return_token_type_ids,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_length=return_length,
                return_tensors=None,  # We convert the whole batch to tensors at the end
                prepend_batch_axis=False,
                verbose=verbose,
            )

            # Accumulate per-example outputs into column-wise batch lists.
            for key, value in outputs.items():
                if key not in batch_outputs:
                    batch_outputs[key] = []
                batch_outputs[key].append(value)

        # Pad the whole batch at once, now that all lengths are known.
        batch_outputs = self.pad(
            batch_outputs,
            padding=padding_strategy.value,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
        )

        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)

        return batch_outputs
+
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING)
def encode(
    self,
    table: "pd.DataFrame",
    query: Optional[TextInput] = None,
    answer: Optional[str] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy, TapexTruncationStrategy] = None,
    max_length: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> List[int]:
    """
    Prepare a table, a string and possible answer for the model, returning only the token ids.

    This method does not return token type IDs, attention masks, etc. which are necessary for the
    model to work correctly. Use this method if you want to build your processing on your own,
    otherwise refer to `__call__`.
    """
    # Delegate the full encoding to `encode_plus` and keep only the token ids.
    full_encoding = self.encode_plus(
        table,
        query=query,
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors=return_tensors,
        **kwargs,
    )
    return full_encoding["input_ids"]
+
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
def encode_plus(
    self,
    table: "pd.DataFrame",
    query: Optional[TextInput] = None,
    answer: Optional[str] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str] = None,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Encode a single table/query pair (and optional answer) into model-ready inputs."""
    # Resolve the user-facing `padding`/`truncation` values (including the legacy
    # 'truncation_strategy' / 'pad_to_max_length' kwargs) into concrete strategies first.
    resolved = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    padding_strategy, truncation_strategy, max_length, kwargs = resolved

    # Forward everything to the strategy-typed implementation.
    return self._encode_plus(
        table=table,
        query=query,
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding_strategy=padding_strategy,
        truncation_strategy=truncation_strategy,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
+
def _encode_plus(
    self,
    table: "pd.DataFrame",
    query: Optional[TextInput] = None,
    answer: Optional[str] = None,
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Linearize one table/query pair, tokenize it, and build the model inputs."""
    # Offset mappings are a fast-tokenizer-only feature; fail loudly instead of ignoring the flag.
    if return_offsets_mapping:
        raise NotImplementedError(
            "return_offset_mapping is not available when using Python tokenizers. "
            "To use this feature, change your tokenizer to one deriving from "
            "transformers.PreTrainedTokenizerFast. "
            "More information on available tokenizers at "
            "https://github.com/huggingface/transformers/pull/2674"
        )

    linearized = self.prepare_table_query(
        table, query, answer, truncation_strategy=truncation_strategy, max_length=max_length
    )

    # Lower-casing happens after linearization so table content and query are treated alike.
    if self.do_lower_case:
        linearized = linearized.lower()

    token_ids = self.convert_tokens_to_ids(self.tokenize(linearized))
    return self.prepare_for_model(
        ids=token_ids,
        add_special_tokens=add_special_tokens,
        padding=padding_strategy.value,
        truncation=truncation_strategy.value,
        max_length=max_length,
        stride=stride,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        prepend_batch_axis=True,
        return_attention_mask=return_attention_mask,
        return_token_type_ids=return_token_type_ids,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        verbose=verbose,
    )
+
def target_call_func(
    self,
    answer: Union[str, List[str]],
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy] = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Tokenize and prepare the answer label(s) for the model.

    Args:
        answer (`str` or `List[str]`):
            Corresponding answer supervision to the queries for training the model.
    """
    # Both branches forward exactly the same keyword arguments (note: `stride` is
    # intentionally not forwarded, matching the public signatures of both targets).
    forwarded = dict(
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )

    # A list/tuple of answers takes the batch path; a single string the scalar path.
    if isinstance(answer, (list, tuple)):
        return self.target_batch_encode_plus(**forwarded)
    return self.target_encode_plus(**forwarded)
+
def target_batch_encode_plus(
    self,
    answer: List[str],
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str] = None,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Prepare a batch of answer strings for the model.

    Args:
        answer (`List[str]`):
            Corresponding answer supervision to the queries for training the model.
    """
    # Translate legacy 'truncation_strategy' / 'pad_to_max_length' kwargs into strategies.
    resolved = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    padding_strategy, truncation_strategy, max_length, kwargs = resolved

    return self._target_batch_encode_plus(
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding_strategy=padding_strategy,
        truncation_strategy=truncation_strategy,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
+
def _target_batch_encode_plus(
    self,
    answer: List[str],
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Tokenize a batch of answer strings and pad them together.

    Each answer is first encoded individually without padding; padding (and the
    attention mask) is then applied once over the whole batch, and the batch is
    converted to `return_tensors` in a single final step.
    """
    batch_outputs = {}
    for text in answer:
        if self.do_lower_case:
            text = text.lower()

        tokens = self.tokenize(text)
        outputs = self.prepare_for_model(
            ids=self.convert_tokens_to_ids(tokens),
            add_special_tokens=add_special_tokens,
            padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterwards
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=None,  # we pad in batch afterwards
            return_attention_mask=False,  # we pad in batch afterwards
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            return_tensors=None,  # we convert the whole batch to tensors at the end
            prepend_batch_axis=False,
            verbose=verbose,
        )

        for key, value in outputs.items():
            batch_outputs.setdefault(key, []).append(value)

    batch_outputs = self.pad(
        batch_outputs,
        padding=padding_strategy.value,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_attention_mask=return_attention_mask,
    )

    # Bug fix: the previous code built `BatchEncoding(batch_outputs, tensor_type=return_tensors)`
    # and then returned `BatchEncoding(batch_outputs)` — a second wrap that copied the data and
    # discarded the tensor-typed encoding object. Return the converted encoding directly.
    return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
def target_encode(
    self,
    answer: str,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str, TruncationStrategy, TapexTruncationStrategy] = None,
    max_length: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> List[int]:
    """
    Prepare the answer string for the model, returning only its token ids.

    This method does not return token type IDs, attention masks, etc. which are necessary for the
    model to work correctly. Use this method if you want to build your processing on your own,
    otherwise refer to `__call__`.

    Args:
        answer (`str`):
            Corresponding answer supervision to the queries for training the model
    """
    # Full encoding via `target_encode_plus`, then project down to the ids.
    return self.target_encode_plus(
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        return_tensors=return_tensors,
        **kwargs,
    )["input_ids"]
+
def target_encode_plus(
    self,
    answer: str,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Union[bool, str] = None,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """
    Prepare an answer string for the model.

    Args:
        answer (`str`):
            Corresponding answer supervision to the queries for training the model.
    """
    # Resolve the legacy 'truncation_strategy' / 'pad_to_max_length' kwargs up front.
    resolved = self._get_padding_truncation_strategies(
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        verbose=verbose,
        **kwargs,
    )
    padding_strategy, truncation_strategy, max_length, kwargs = resolved

    return self._target_encode_plus(
        answer=answer,
        add_special_tokens=add_special_tokens,
        padding_strategy=padding_strategy,
        truncation_strategy=truncation_strategy,
        max_length=max_length,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        return_token_type_ids=return_token_type_ids,
        return_attention_mask=return_attention_mask,
        return_special_tokens_mask=return_special_tokens_mask,
        return_offsets_mapping=return_offsets_mapping,
        return_length=return_length,
        verbose=verbose,
        **kwargs,
    )
+
def _target_encode_plus(
    self,
    answer: str,
    add_special_tokens: bool = True,
    padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
    truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length: Optional[int] = None,
    stride: int = 0,
    pad_to_multiple_of: Optional[int] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    return_token_type_ids: Optional[bool] = None,
    return_attention_mask: Optional[bool] = None,
    return_overflowing_tokens: bool = False,
    return_special_tokens_mask: bool = False,
    return_offsets_mapping: bool = False,
    return_length: bool = False,
    verbose: bool = True,
    **kwargs,
) -> BatchEncoding:
    """Tokenize a single answer string and build the model inputs for it."""
    # Offset mappings require a fast tokenizer; reject the flag explicitly.
    if return_offsets_mapping:
        raise NotImplementedError(
            "return_offset_mapping is not available when using Python tokenizers. "
            "To use this feature, change your tokenizer to one deriving from "
            "transformers.PreTrainedTokenizerFast. "
            "More information on available tokenizers at "
            "https://github.com/huggingface/transformers/pull/2674"
        )

    answer_text = answer.lower() if self.do_lower_case else answer
    answer_ids = self.convert_tokens_to_ids(self.tokenize(answer_text))

    return self.prepare_for_model(
        ids=answer_ids,
        add_special_tokens=add_special_tokens,
        padding=padding_strategy.value,
        truncation=truncation_strategy.value,
        max_length=max_length,
        stride=stride,
        pad_to_multiple_of=pad_to_multiple_of,
        return_tensors=return_tensors,
        prepend_batch_axis=True,
        return_attention_mask=return_attention_mask,
        return_token_type_ids=return_token_type_ids,
        return_overflowing_tokens=return_overflowing_tokens,
        return_special_tokens_mask=return_special_tokens_mask,
        return_length=return_length,
        verbose=verbose,
    )
+
def prepare_table_query(
    self,
    table,
    query,
    answer=None,
    truncation_strategy: Union[str, TruncationStrategy, TapexTruncationStrategy] = TruncationStrategy.DO_NOT_TRUNCATE,
    max_length=None,
):
    """
    Linearize a table and concatenate it with the corresponding query.

    Optionally handles truncation of the table: cells are always truncated to
    `self.max_cell_length` tokens, and with `TapexTruncationStrategy.DROP_ROWS_TO_FIT`
    whole rows are dropped to fit `max_length`. An answer can be provided for more
    precise truncation.

    Bug fix: `truncation_strategy` previously defaulted to the *typing object*
    `Union[str, TruncationStrategy, TapexTruncationStrategy]` — an annotation
    mistakenly written as a default value (`=` instead of `:`). It is now a properly
    annotated parameter defaulting to `TruncationStrategy.DO_NOT_TRUNCATE`. Both
    internal call sites pass the keyword explicitly, so their behavior is unchanged.
    """
    if not table.empty:
        # step 1: create table dictionary
        table_content = {"header": list(table.columns), "rows": [list(row.values) for i, row in table.iterrows()]}

        # step 2: modify table internally
        # always truncate table cells based on self.max_cell_length
        # optionally truncate rows if truncation_strategy is set to it
        self.truncate_table_cells(table_content, query, answer)
        if truncation_strategy == TapexTruncationStrategy.DROP_ROWS_TO_FIT:
            self.truncate_table_rows(table_content, query, answer, max_length=max_length)

        # step 3: linearize table
        linear_table = self.table_linearize.process_table(table_content)
    else:
        linear_table = ""

    if linear_table == "":
        logger.warning(
            "You provide an empty table, or all cells contain much tokens (e.g., >= 1024 tokens). "
            + f"Please carefully check the corresponding table with the query : {query}."
        )
    if query == "":
        logger.warning("You provide nothing to query with respect to the table.")
    # step 4: concatenate query with linear_table
    separator = " " if query and linear_table else ""
    joint_input = (query + separator + linear_table) if query else linear_table

    return joint_input
+
def truncate_table_cells(self, table_content: Dict, question: str, answer: List):
    """
    Truncate every over-long cell in `table_content` in place, and rewrite any answer
    item that referenced a truncated cell so supervision stays consistent.
    """
    # TODO (Qian): is it possible to revert the original cell if it is in the final answer?
    replaced = {}
    for row in table_content["rows"]:
        for col_idx, cell in enumerate(row):
            shortened = self.truncate_cell(cell)
            # `None` means the cell was short enough and left untouched.
            if shortened is not None:
                replaced[cell] = shortened
                row[col_idx] = shortened

    # Keep the answer list aligned with the (possibly shortened) cell values.
    if answer is not None:
        for idx, item in enumerate(answer):
            if item in replaced:
                answer[idx] = replaced[item]
+
def truncate_cell(self, cell_value):
    """
    Truncate one cell value to at most `self.max_cell_length` tokens.

    Returns the shortened string when truncation happened, the value unchanged for
    numeric or blank cells, and `None` when the cell was already short enough.
    """
    # Numeric cells are kept verbatim.
    if isinstance(cell_value, (int, float)):
        return cell_value
    # Blank / whitespace-only cells need no tokenization either.
    if cell_value.strip() == "":
        return cell_value

    cell_tokens = self.tokenize(cell_value)
    if len(cell_tokens) < self.max_cell_length:
        # `None` tells the caller nothing changed.
        return None
    return self.convert_tokens_to_string(cell_tokens[: self.max_cell_length])
+
def truncate_table_rows(
    self, table_content: Dict, question: str, answer: Optional[Union[str, List[str]]] = None, max_length=None
):
    """
    Drop table rows in place so the linearized table plus question fits `max_length`.

    Args:
        table_content:
            {"header": xxx, "rows": xxx, "id" (Optionally): xxx}
        question:
            natural language sentence
        answer:
            if for training, is the supervision; otherwise will be empty
    """
    delete_ratio, token_budget = self.estimate_delete_ratio(table_content, question, max_length)
    # First pass: randomly remove rows unrelated to the question/answer.
    self.delete_unrelated_rows(table_content, question, answer, delete_ratio)
    # Second pass: keep the longest prefix of rows whose token total fits the budget.
    kept_rows = 0
    for row_index, row in enumerate(table_content["rows"]):
        row_text = self.table_linearize.process_row(row, row_index + 1)
        row_token_len = len(self.tokenize(row_text))
        if row_token_len > token_budget:
            break
        token_budget -= row_token_len
        kept_rows += 1
    del table_content["rows"][kept_rows:]
+
def estimate_delete_ratio(self, table_content: Dict, question: str, max_length=None):
    """
    Estimate which fraction of rows must be deleted for the encoding to fit `max_length`.

    Returns a `(delete_ratio, remaining_token_budget)` tuple, where the budget is the
    token space left for rows after the question and header are accounted for.
    """
    if "header" not in table_content or "rows" not in table_content:
        raise ValueError("The table content should contain both 'header' and 'rows' keys.")
    # Special tokens are only prepended to the question, so count them there once.
    question_token_len = len(self.tokenize(question, add_special_tokens=True))
    header_text = self.table_linearize.process_header(table_content["header"])
    header_token_len = len(self.tokenize(header_text, add_special_tokens=False))
    # Token space left over for the rows themselves.
    # NOTE(review): assumes callers always pass a concrete `max_length` — `None` would raise here.
    remain_token_len = max_length - (question_token_len + header_token_len)

    # Linearize all rows with a generic row index (100) to roughly estimate total row length.
    rows_string = "".join(
        self.table_linearize.process_row(row_example, 100) + " " for row_example in table_content["rows"]
    )
    rows_token_len = len(self.tokenize(rows_string))

    if rows_token_len < remain_token_len:
        # Everything fits — nothing needs deleting.
        return 0.0, remain_token_len
    # Otherwise return a rough deletion rate proportional to the overflow.
    return 1.0 - remain_token_len / rows_token_len, remain_token_len
+
def delete_unrelated_rows(self, table_content: Dict, question: str, answer: List, delete_ratio: float):
    """
    Randomly delete rows that appear unrelated to the question/answer, in place.

    A row counts as "related" when any of its cells matches an answer item or a question
    word; related rows and their two neighbours on each side are never dropped.
    `delete_ratio` bounds how many of the remaining rows are removed. The argument
    `answer` is used only during training.
    """
    truncated_unrelated_indices = []
    related_indices = []
    if answer is None or len(answer) == 0:
        answer_set = set()
    else:
        answer_set = {ans_ex.lower() for ans_ex in answer}
    # add question key words into answer set
    if question is not None:
        answer_set.update(question.split())
    question_set = set(question.strip("?!.,").split(" "))
    row_max_len = len(table_content["rows"])
    for _row_idx, row in enumerate(table_content["rows"]):
        lower_row = {str(cell).lower() for cell in row}
        if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0:
            truncated_unrelated_indices.append(_row_idx)
        else:
            # add neighbours to preserve information aggressively
            related_indices.extend([_row_idx - 2, _row_idx - 1, _row_idx, _row_idx + 1, _row_idx + 2])

    # remove rows protected as (neighbours of) related rows; a set makes membership O(1)
    # instead of the previous O(n) list scan per candidate
    related_index_set = set(related_indices)
    truncated_unrelated_indices = [
        _row_idx for _row_idx in truncated_unrelated_indices if _row_idx not in related_index_set
    ]
    # select some cases to drop
    drop_items = min(len(truncated_unrelated_indices), int(len(table_content["rows"]) * delete_ratio))
    # Bug fix: `random.choices` samples WITH replacement, so duplicate picks meant fewer
    # than `drop_items` rows were actually dropped. `random.sample` draws distinct indices.
    drop_row_indices = set(random.sample(truncated_unrelated_indices, drop_items))

    for _row_idx in reversed(range(row_max_len)):
        if _row_idx in drop_row_indices:
            del table_content["rows"][_row_idx]

    # Log whenever rows were dropped from an identified table. The count is an integer,
    # so the previous "{:.2f}" float format was a wrong format spec; use lazy %-formatting.
    if "id" in table_content and len(drop_row_indices) > 0:
        logger.warning("Deleted %d rows in table %s", len(drop_row_indices), table_content["id"])
+
+
+__all__ = ["TapexTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4bccdd12c9703313951e223d0ac1955f9ca1581
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
# Static type checkers see the real submodules; at runtime this module object is
# replaced by a `_LazyModule` proxy so the heavy submodule imports only happen on
# first attribute access.
if TYPE_CHECKING:
    from .configuration_trajectory_transformer import *
    from .modeling_trajectory_transformer import *
else:
    import sys

    _file = globals()["__file__"]
    # Swap this module in `sys.modules` for a lazy proxy built from the computed
    # import structure of this file.
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a267cc4c4b82c6703eaa3a732d3fca1606a60d7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py
@@ -0,0 +1,155 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TrajectoryTransformer model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class TrajectoryTransformerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`TrajectoryTransformerModel`]. It is used to
    instantiate an TrajectoryTransformer model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    TrajectoryTransformer
    [CarlCochet/trajectory-transformer-halfcheetah-medium-v2](https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 100):
            Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be
            represented by the `trajectories` passed when calling [`TrajectoryTransformerModel`]
        action_weight (`int`, *optional*, defaults to 5):
            Weight of the action in the loss function
        reward_weight (`int`, *optional*, defaults to 1):
            Weight of the reward in the loss function
        value_weight (`int`, *optional*, defaults to 1):
            Weight of the value in the loss function
        block_size (`int`, *optional*, defaults to 249):
            Size of the blocks in the trajectory transformer.
        action_dim (`int`, *optional*, defaults to 6):
            Dimension of the action space.
        observation_dim (`int`, *optional*, defaults to 17):
            Dimension of the observation space.
        transition_dim (`int`, *optional*, defaults to 25):
            Dimension of the transition space.
        n_layer (`int`, *optional*, defaults to 4):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_embd (`int`, *optional*, defaults to 128):
            Dimensionality of the embeddings and hidden states.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`int`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        learning_rate (`float`, *optional*, defaults to 0.0006):
            Learning rate stored on the configuration; not used by the architecture itself.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        kaiming_initializer_range (`float, *optional*, defaults to 1):
            A coefficient scaling the negative slope of the kaiming initializer rectifier for EinLinear layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
    Example:

    ```python
    >>> from transformers import TrajectoryTransformerConfig, TrajectoryTransformerModel

    >>> # Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
    >>> configuration = TrajectoryTransformerConfig()

    >>> # Initializing a model (with random weights) from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration
    >>> model = TrajectoryTransformerModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "trajectory_transformer"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map the standard config attribute names onto this model's GPT-style names.
    attribute_map = {
        "hidden_size": "n_embd",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=100,
        action_weight=5,
        reward_weight=1,
        value_weight=1,
        block_size=249,
        action_dim=6,
        observation_dim=17,
        transition_dim=25,
        n_layer=4,
        n_head=4,
        n_embd=128,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        learning_rate=0.0006,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        kaiming_initializer_range=1,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=50256,
        eos_token_id=50256,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.action_weight = action_weight
        self.reward_weight = reward_weight
        self.value_weight = value_weight
        self.max_position_embeddings = max_position_embeddings
        self.block_size = block_size
        self.action_dim = action_dim
        self.observation_dim = observation_dim
        self.transition_dim = transition_dim
        self.learning_rate = learning_rate
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.resid_pdrop = resid_pdrop
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.kaiming_initializer_range = kaiming_initializer_range
        self.use_cache = use_cache
        # Token ids and any remaining kwargs are handled by the PretrainedConfig base class.
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+
+__all__ = ["TrajectoryTransformerConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7f7806671dbace1a10bd60d93e6782e27a5136
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
@@ -0,0 +1,70 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TrajectoryTransformer pytorch checkpoint conversion"""
+
+import torch
+import trajectory.utils as utils
+
+from transformers import TrajectoryTransformerModel
+
+
class Parser(utils.Parser):
    # CLI defaults for the original trajectory-transformer argument parser:
    # which D4RL dataset to load and which config module to read from.
    dataset: str = "halfcheetah-medium-expert-v2"
    config: str = "config.offline"
+
def convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(logbase, dataset, loadpath, epoch, device):
    """Convert an original trajectory-transformer checkpoint to the HF format.

    The original GPT stores each block's MLP as an
    ``nn.Sequential(Linear, GELU, Linear, Dropout)``; the HF port exposes the
    same four sub-modules as named attributes (``l1``/``act``/``l2``/``drop``),
    so the Sequential entries are copied over index by index. The converted
    state dict is written to ``pytorch_model.bin`` in the working directory.

    Args:
        logbase: root log directory of the original repository.
        dataset: dataset name the checkpoint was trained on.
        loadpath: checkpoint sub-path inside the log directory.
        epoch: checkpoint epoch to load (or "latest").
        device: device to load the original model on.
    """
    gpt, gpt_epoch = utils.load_model(logbase, dataset, loadpath, epoch=epoch, device=device)
    trajectory_transformer = TrajectoryTransformerModel(gpt.config)

    # Embedding stem and decoder head map one-to-one.
    trajectory_transformer.tok_emb.load_state_dict(gpt.tok_emb.state_dict())
    trajectory_transformer.pos_emb = gpt.pos_emb
    trajectory_transformer.drop.load_state_dict(gpt.drop.state_dict())
    trajectory_transformer.ln_f.load_state_dict(gpt.ln_f.state_dict())
    trajectory_transformer.head.load_state_dict(gpt.head.state_dict())

    # Copy each transformer block; iterate the two block lists in lockstep
    # (the original loop bound an unused `block` and re-indexed both lists on
    # every line).
    for src, dst in zip(gpt.blocks, trajectory_transformer.blocks):
        dst.ln1.load_state_dict(src.ln1.state_dict())
        dst.ln2.load_state_dict(src.ln2.state_dict())
        dst.attn.load_state_dict(src.attn.state_dict())

        # Original MLP Sequential layout: [0]=Linear, [1]=GELU, [2]=Linear, [3]=Dropout.
        dst.l1.load_state_dict(src.mlp[0].state_dict())
        dst.act.load_state_dict(src.mlp[1].state_dict())
        dst.l2.load_state_dict(src.mlp[2].state_dict())
        dst.drop.load_state_dict(src.mlp[3].state_dict())

    torch.save(trajectory_transformer.state_dict(), "pytorch_model.bin")
+
+
+if __name__ == "__main__":
+ """
+ To run this script you will need to install the original repository to run the original model. You can find it
+ here: https://github.com/jannerm/trajectory-transformer From this repository code you can also download the
+ original pytorch checkpoints.
+
+ Run with the command:
+
+ ```sh
+ >>> python convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py --dataset
+ ... --gpt_loadpath
+ ```
+ """
+
+ args = Parser().parse_args("plan")
+ convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(
+ args.logbase, args.dataset, args.gpt_loadpath, args.gpt_epoch, args.device
+ )
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9a0111d22218af6af00d7561107d0c70def2db
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py
@@ -0,0 +1,610 @@
+# coding=utf-8
+# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TrajectoryTransformer model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+
+from ....modeling_utils import PreTrainedModel
+from ....utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_trajectory_transformer import TrajectoryTransformerConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
+_CONFIG_FOR_DOC = "TrajectoryTransformerConfig"
+
+
def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path):
    """Load TensorFlow checkpoint weights into a PyTorch model, in place.

    Args:
        model: the PyTorch model to populate.
        config: model configuration (unused here; kept for API symmetry with
            the other `load_tf_weights_in_*` helpers).
        tf_checkpoint_path: path to the TensorFlow checkpoint to read.

    Returns:
        The same `model`, with its parameters overwritten from the checkpoint.

    Raises:
        ImportError: if TensorFlow is not installed.
        ValueError: if a checkpoint array's shape does not match the
            corresponding PyTorch parameter.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        # Walk the scope path, descending through the module tree one name
        # component at a time.
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            # Map TF variable-name components onto PyTorch attribute names.
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    # NOTE(review): this `continue` only skips the current name
                    # component (inner loop), not the whole variable — kept
                    # as-is to match the other load_tf_weights helpers.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF stores dense kernels transposed relative to torch Linear weights.
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            # Fix: the original caught AssertionError here, which the ValueError
            # raised above can never trigger, so the shape details were never
            # attached to the exception.
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model
+
+
@dataclass
class TrajectoryTransformerOutput(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads,
            sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the
            attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average
            in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TrajectoryTransformerConfig
    load_tf_weights = load_tf_weights_in_trajectory_transformer
    base_model_prefix = "trajectory_transformer"
    main_input_name = "trajectories"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # Standard GPT-style initialization for Linear/Embedding and LayerNorm;
        # EinLinear instead gets a per-model Kaiming-uniform init.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, EinLinear):
            for i in range(module.n_models):
                # `a` rescales the Kaiming slope by kaiming_initializer_range.
                nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range)
                if module.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i])
                    # Bias bound is additionally scaled by initializer_range.
                    bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range
                    nn.init.uniform_(module.bias[i], -bound, bound)
+
+
# Docstring fragments consumed by the add_start_docstrings* decorators below.
TRAJECTORY_TRANSFORMER_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TrajectoryTransformerConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING = r"""
    Args:
        trajectories (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Batch of trajectories, where a trajectory is a sequence of states, actions and rewards.
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`, *optional*):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        targets (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Desired targets used to compute the loss.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
class EinLinear(nn.Module):
    """A batch of `n_models` independent linear layers applied in parallel.

    Model `i` owns its own `(out_features, in_features)` weight matrix and is
    applied to slice `input[:, i, :]`, so a single einsum evaluates all
    `n_models` heads at once.
    """

    def __init__(self, n_models, in_features, out_features, bias):
        super().__init__()
        self.n_models = n_models
        self.out_features = out_features
        self.in_features = in_features
        # One weight matrix per model: [ n_models x out_features x in_features ].
        # Parameters are left uninitialized here; the enclosing PreTrainedModel
        # initializes them via `_init_weights` (or call `reset_parameters`).
        self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features))
        if bias:
            # One bias vector per model: [ n_models x out_features ]
            self.bias = nn.Parameter(torch.Tensor(n_models, out_features))
        else:
            self.register_parameter("bias", None)

    def reset_parameters(self):
        """Initialize each model's weight (and bias) the way `nn.Linear` does."""
        for i in range(self.n_models):
            nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
            if self.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i])
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(self.bias[i], -bound, bound)

    def forward(self, input):
        """
        Args:
            input (`torch.FloatTensor` of shape `(B, n_models, input_dim)`):
                The input to the layer.

        Returns:
            `torch.FloatTensor` of shape `(B, n_models, output_dim)`.
        """
        # [ batch_size x n_models x output_dim ]
        output = torch.einsum("eoi,bei->beo", self.weight, input)
        if self.bias is not None:
            # Fix: the original raised a bare RuntimeError() here, making a
            # bias-enabled EinLinear unusable even though `__init__` and
            # `reset_parameters` fully support one. The `(n_models,
            # out_features)` bias broadcasts over the batch dimension.
            output = output + self.bias
        return output
+
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an extra mask over value tokens."""

    def __init__(self, config):
        super().__init__()

        if config.n_embd % config.n_head != 0:
            raise ValueError(f"n_head ({config.n_head}) should be a divisor of n_embd ({config.n_embd})")

        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)

        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)

        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
            persistent=False,
        )

        # mask previous value estimates
        # `squeeze()` returns a view, so this in-place write also zeroes the
        # corresponding columns of `self.mask`: no position can attend to the
        # value-estimate token of any earlier transition.
        joined_dim = config.observation_dim + config.action_dim + 2
        self.mask.squeeze()[:, joined_dim - 1 :: joined_dim] = 0

        self.n_head = config.n_head

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        """Apply masked multi-head self-attention.

        Returns `(output, present[, attn_weights])` where `present` is the
        `(key, value)` cache when `use_cache` is True, else `None`.
        """
        batch_size, sequence_length, embedding_dim = hidden_states.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # [ batch_size x n_heads x sequence_length x head_dim ]
        key = (
            self.key(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )
        query = (
            self.query(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )
        value = (
            self.value(hidden_states)
            .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head)
            .transpose(1, 2)
        )

        # Prepend cached keys/values from earlier decoding steps.
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # causal self-attention
        # [ batch_size x n_heads x sequence_length x sequence_length ]
        attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1)))
        # NOTE(review): the mask is sliced by the *query* length only; when
        # layer_past is non-empty (key length > query length) this relies on
        # broadcasting and effectively disables the value-token masking for
        # single-step cached decoding — confirm this is intended.
        attn_weights = attn_weights.masked_fill(
            self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min
        )
        attn_weights = F.softmax(attn_weights, dim=-1)
        # Keep a copy of the pre-dropout attention map; apparently for
        # visualization/debugging — it is not read elsewhere in this file.
        self._attn_map = attn_weights.clone()
        attn_weights = self.attn_drop(attn_weights)

        output = torch.matmul(attn_weights, value)
        # [ batch_size x sequence_length x embedding_dim ]
        # re-assemble all head outputs side by side
        output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, embedding_dim)

        # output projection
        output = self.resid_drop(self.proj(output))

        outputs = (output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
+
+
class Block(nn.Module):
    """A pre-LayerNorm GPT block: causal self-attention followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)

        # MLP
        # Kept as separate named attributes (rather than nn.Sequential) —
        # the checkpoint-conversion script copies each piece by name.
        self.l1 = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.act = nn.GELU()
        self.l2 = nn.Linear(4 * config.n_embd, config.n_embd)
        self.drop = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        """Run one transformer block.

        Returns `(hidden_states, present[, attentions])` when `use_cache` is
        set, otherwise `(hidden_states[, attentions])`.
        """
        residual = hidden_states
        hidden_states = self.ln1(hidden_states)

        attn_outputs = self.attn(
            hidden_states, layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions
        )
        attn_output = attn_outputs[0]
        # attn_outputs[1:] is (present[, attn_weights]); present is None when not caching.
        outputs = attn_outputs[1:]
        hidden_states = attn_output + residual

        # MLP sub-block with its own residual connection.
        residual = hidden_states
        hidden_states = self.ln2(hidden_states)
        hidden_states = self.l1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.l2(hidden_states)
        hidden_states = residual + self.drop(hidden_states)

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            # Drop the (None) present entry when not caching.
            outputs = (hidden_states,) + outputs[1:]

        return outputs
+
+
@add_start_docstrings(
    "The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top.",
    TRAJECTORY_TRANSFORMER_START_DOCSTRING,
)
class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
    """The full GPT language model, with a context size of `block_size`.

    Trajectories are flattened sequences of discretized transition tokens; each
    of the `transition_dim` positions within a transition gets its own
    vocabulary slice (see `offset_tokens`) and its own output head (the
    `EinLinear` decoder).
    """

    def __init__(self, config):
        super().__init__(config)

        # input embedding stem (+1 for stop token)
        self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd)

        # Learned absolute position embeddings, one per position in the block.
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        # decoder head: one independent linear head per position within a transition
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False)

        self.vocab_size = config.vocab_size
        # The stop token is the single extra id past all per-dimension vocab slices.
        self.stop_token = config.vocab_size * config.transition_dim
        self.block_size = config.block_size

        self.observation_dim = config.observation_dim
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.embedding_dim = config.n_embd

        # Per-token-type loss weights (see `forward`).
        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight

        self.gradient_checkpointing = False

        self.post_init()

    def get_block_size(self):
        """Return the maximum sequence length the model can attend over."""
        return self.block_size

    def offset_tokens(self, trajectories):
        """Shift each token id into the vocab slice of its position within a transition.

        Position `d` of a transition uses ids `[d * vocab_size, (d + 1) * vocab_size)`,
        so the shared embedding table holds a separate sub-vocabulary per
        dimension. Tokens equal to `vocab_size` are mapped to the stop token.
        """
        _, sequence_length = trajectories.shape

        n_states = int(np.ceil(sequence_length / self.transition_dim))

        offsets = torch.arange(self.transition_dim) * self.vocab_size
        offsets = offsets.repeat(n_states).to(trajectories.device)

        offset_trajectories = trajectories + offsets[:sequence_length]
        offset_trajectories[trajectories == self.vocab_size] = self.stop_token
        return offset_trajectories

    def pad_to_full_observation(self, hidden_states):
        """Zero-pad the sequence so its length is a multiple of `transition_dim`.

        Returns the padded states reshaped to `(-1, transition_dim,
        embedding_dim)` — the layout `EinLinear` expects — plus the number of
        padding steps added.
        """
        batch_size, sequence_length, _ = hidden_states.shape

        n_pad = (self.transition_dim - sequence_length % self.transition_dim) % self.transition_dim
        padding = torch.zeros(batch_size, n_pad, self.embedding_dim, device=hidden_states.device)

        # [ batch_size x padded_sequence_length' x embedding_dim ]
        hidden_states_pad = torch.cat([hidden_states, padding], dim=1)
        hidden_states_pad = hidden_states_pad.view(-1, self.transition_dim, self.embedding_dim)

        return hidden_states_pad, n_pad

    @add_start_docstrings_to_model_forward(
        TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @replace_return_docstrings(output_type=TrajectoryTransformerOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        trajectories: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        targets: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TrajectoryTransformerOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> import numpy as np
        >>> import torch
        >>> from transformers import TrajectoryTransformerModel

        >>> model = TrajectoryTransformerModel.from_pretrained(
        ...     "CarlCochet/trajectory-transformer-halfcheetah-medium-v2"
        ... )
        >>> model.eval()

        >>> observations_dim, action_dim, batch_size = 17, 6, 256
        >>> seq_length = observations_dim + action_dim + 1

        >>> trajectories = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)])
        >>> targets = torch.LongTensor([np.random.permutation(seq_length) for _ in range(batch_size)])

        >>> outputs = model(
        ...     trajectories,
        ...     targets=targets,
        ...     use_cache=True,
        ...     output_attentions=True,
        ...     output_hidden_states=True,
        ...     return_dict=True,
        ... )
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # NOTE(review): unlike the flags above, `use_cache` and `return_dict`
        # are not defaulted from the config here; `None` behaves like `False`.

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.blocks))

        batch_size, sequence_length = trajectories.size()

        if sequence_length > self.block_size:
            raise ValueError("Cannot forward, model block size is exhausted.")

        offset_trajectories = self.offset_tokens(trajectories)
        # [ batch_size x sequence_length x embedding_dim ]
        # forward the GPT model
        token_embeddings = self.tok_emb(offset_trajectories)  # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :sequence_length, :]  # each position maps to a (learnable) vector

        hidden_states = self.drop(token_embeddings + position_embeddings)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    layer_past,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(hidden_states, layer_past, use_cache, output_attentions)

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # Attention weights sit after `present` in the block output
                # only when caching is on.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # [ batch_size x sequence_length x embedding_dim ]
        hidden_state = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Pad so EinLinear sees whole transitions, then trim the logits back.
        hidden_states_pad, n_pad = self.pad_to_full_observation(hidden_state)

        logits = self.head(hidden_states_pad)
        logits = logits.reshape(batch_size, sequence_length + n_pad, self.vocab_size + 1)
        logits = logits[:, :sequence_length]

        # if we are given some desired targets also calculate the loss
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction="none")
            if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1:
                # make weights
                n_states = int(np.ceil(sequence_length / self.transition_dim))
                weights = torch.cat(
                    [
                        torch.ones(self.observation_dim, device=trajectories.device),
                        torch.ones(self.action_dim, device=trajectories.device) * self.action_weight,
                        torch.ones(1, device=trajectories.device) * self.reward_weight,
                        torch.ones(1, device=trajectories.device) * self.value_weight,
                    ]
                )
                weights = weights.repeat(n_states)
                # The first weight is dropped — presumably to align per-position
                # weights with next-token targets (one-step shift); confirm
                # against the original repository.
                weights = weights[1:].repeat(batch_size, 1)
                loss = loss * weights.view(-1)
            # NOTE(review): assumes `attention_mask` is always provided
            # together with `targets`; a `None` mask raises AttributeError here.
            loss = (loss * attention_mask.view(-1)).mean()
        else:
            loss = None

        if not return_dict:
            return tuple(v for v in [loss, logits, presents, all_hidden_states, all_self_attentions] if v is not None)

        return TrajectoryTransformerOutput(
            loss=loss,
            logits=logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
# Public API of this module.
__all__ = [
    "TrajectoryTransformerModel",
    "TrajectoryTransformerPreTrainedModel",
    "load_tf_weights_in_trajectory_transformer",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac9a2cbf4766bd187d90fdaf46505db3e92b68f
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
# Eager star-imports for static type checkers; at runtime the module object is
# replaced by a _LazyModule that resolves attributes on first access.
if TYPE_CHECKING:
    from .configuration_transfo_xl import *
    from .modeling_tf_transfo_xl import *
    from .modeling_transfo_xl import *
    from .tokenization_transfo_xl import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..23972deae2ca9b31ce5adcf2aaf4f6e6cd4b7587
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py
@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Transformer XL configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class TransfoXLConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`TransfoXLModel`] or a [`TFTransfoXLModel`]. It is
    used to instantiate a Transformer-XL model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the TransfoXL
    [transfo-xl/transfo-xl-wt103](https://huggingface.co/transfo-xl/transfo-xl-wt103) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 267735):
            Vocabulary size of the Transformer-XL model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`TransfoXLModel`] or [`TFTransfoXLModel`].
        cutoffs (`List[int]`, *optional*, defaults to `[20000, 40000, 200000]`):
            Cutoffs for the adaptive softmax.
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the model's hidden states.
        d_embed (`int`, *optional*, defaults to 1024):
            Dimensionality of the embeddings
        n_head (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        d_head (`int`, *optional*, defaults to 64):
            Dimensionality of the model's heads.
        d_inner (`int`, *optional*, defaults to 4096):
            Inner dimension in FF
        div_val (`int`, *optional*, defaults to 4):
            Divisor value for adaptive input and softmax
        pre_lnorm (`boolean`, *optional*, defaults to `False`):
            Whether or not to apply LayerNorm to the input instead of the output in the blocks.
        n_layer (`int`, *optional*, defaults to 18):
            Number of hidden layers in the Transformer encoder.
        mem_len (`int`, *optional*, defaults to 1600):
            Length of the retained previous heads.
        clamp_len (`int`, *optional*, defaults to 1000):
            Use the same pos embeddings after clamp_len.
        same_length (`boolean`, *optional*, defaults to `True`):
            Whether or not to use the same attn length for all tokens
        proj_share_all_but_first (`boolean`, *optional*, defaults to `True`):
            True to share all but first projs, False not to share.
        attn_type (`int`, *optional*, defaults to 0):
            Attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al.
        sample_softmax (`int`, *optional*, defaults to -1):
            Number of samples in the sampled softmax.
        adaptive (`boolean`, *optional*, defaults to `True`):
            Whether or not to use adaptive softmax.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        dropatt (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        untie_r (`boolean`, *optional*, defaults to `True`):
            Whether or not to untie relative position biases.
        init (`str`, *optional*, defaults to `"normal"`):
            Parameter initializer to use.
        init_range (`float`, *optional*, defaults to 0.01):
            Parameters initialized by U(-init_range, init_range).
        proj_init_std (`float`, *optional*, defaults to 0.01):
            Parameters initialized by N(0, init_std)
        init_std (`float`, *optional*, defaults to 0.02):
            Parameters initialized by N(0, init_std)
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers
        eos_token_id (`int`, *optional*, defaults to 0):
            End of stream token id.

    Examples:

    ```python
    >>> from transformers import TransfoXLConfig, TransfoXLModel

    >>> # Initializing a Transformer XL configuration
    >>> configuration = TransfoXLConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = TransfoXLModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "transfo-xl"
    keys_to_ignore_at_inference = ["mems"]
    attribute_map = {
        "n_token": "vocab_size",
        "hidden_size": "d_model",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=267735,
        cutoffs=None,
        d_model=1024,
        d_embed=1024,
        n_head=16,
        d_head=64,
        d_inner=4096,
        div_val=4,
        pre_lnorm=False,
        n_layer=18,
        mem_len=1600,
        clamp_len=1000,
        same_length=True,
        proj_share_all_but_first=True,
        attn_type=0,
        sample_softmax=-1,
        adaptive=True,
        dropout=0.1,
        dropatt=0.0,
        untie_r=True,
        init="normal",
        init_range=0.01,
        proj_init_std=0.01,
        init_std=0.02,
        layer_norm_epsilon=1e-5,
        eos_token_id=0,
        **kwargs,
    ):
        # Avoid a shared mutable default argument; applying the historical
        # default here is behaviorally identical to `cutoffs=[20000, 40000, 200000]`.
        if cutoffs is None:
            cutoffs = [20000, 40000, 200000]
        self.vocab_size = vocab_size
        # Defensive copy so the caller's list is never aliased by the config.
        self.cutoffs = list(cutoffs)
        # The first (head) projection is never tied; whether the remaining
        # cluster projections are tied follows `proj_share_all_but_first`.
        self.tie_projs = [False] + [bool(proj_share_all_but_first)] * len(self.cutoffs)
        self.d_model = d_model
        self.d_embed = d_embed
        self.d_head = d_head
        self.d_inner = d_inner
        self.div_val = div_val
        self.pre_lnorm = pre_lnorm
        self.n_layer = n_layer
        self.n_head = n_head
        self.mem_len = mem_len
        self.same_length = same_length
        self.attn_type = attn_type
        self.clamp_len = clamp_len
        self.sample_softmax = sample_softmax
        self.adaptive = adaptive
        self.dropout = dropout
        self.dropatt = dropatt
        self.untie_r = untie_r
        self.init = init
        self.init_range = init_range
        self.proj_init_std = proj_init_std
        self.init_std = init_std
        self.layer_norm_epsilon = layer_norm_epsilon
        super().__init__(eos_token_id=eos_token_id, **kwargs)

    @property
    def max_position_embeddings(self):
        # Message copied from Transformer-XL documentation: memory-augmented
        # attention means there is no fixed maximum sequence length.
        logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.")
        return -1

    @max_position_embeddings.setter
    def max_position_embeddings(self, value):
        # Message copied from Transformer-XL documentation
        raise NotImplementedError(
            f"The model {self.model_type} is one of the few models that has no sequence length limit."
        )
+
+
+__all__ = ["TransfoXLConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c7b687c4d98974ebf99790a201f7c3221a5498a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py
@@ -0,0 +1,121 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Transformer XL checkpoint and datasets."""
+
+import argparse
+import os
+import pickle
+import sys
+
+import torch
+
+from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
+from transformers.models.deprecated.transfo_xl import tokenization_transfo_xl as data_utils
+from transformers.models.deprecated.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
+from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging
+
+
logging.set_verbosity_info()

# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
# The original TF corpus pickles reference classes by the module paths
# "data_utils.Vocab" / "vocabulary.*"; alias those names onto our tokenizer
# module so `pickle.load` can resolve them.
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules["data_utils"] = data_utils
sys.modules["vocabulary"] = data_utils
+
+
def convert_transfo_xl_checkpoint_to_pytorch(
    tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file
):
    """Convert a Transformer-XL TF checkpoint and/or pre-processed corpus to PyTorch format.

    Args:
        tf_checkpoint_path: Path to a TensorFlow checkpoint, or "" to skip model conversion.
        transfo_xl_config_file: Optional JSON config path; "" uses the default `TransfoXLConfig`.
        pytorch_dump_folder_path: Output folder for the converted vocabulary/dataset/model files.
        transfo_xl_dataset_file: Optional pickled corpus path, or "" to skip dataset conversion.
    """
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo).
        # NOTE(review): `pickle.load` on an arbitrary file can execute attacker-controlled
        # code — only pass corpora you trust (also warned about in the CLI help).
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        # Use os.path.join (consistent with the checkpoint branch below) rather than "/" string concat.
        pytorch_vocab_dump_path = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
        print(f"Save vocabulary to {pytorch_vocab_dump_path}")
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop("vocab", None)
        pytorch_dataset_dump_path = os.path.join(pytorch_dump_folder_path, CORPUS_NAME)
        print(f"Save dataset to {pytorch_dataset_dump_path}")
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)

        print(f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}.")
        # Initialise PyTorch model: empty config file string means "use defaults".
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
        print(f"Building PyTorch model from configuration: {config}")
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch-model weights and config alongside each other.
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
        print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}")
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the folder to store the PyTorch model or dataset/vocab.",
+ )
+ parser.add_argument(
+ "--tf_checkpoint_path",
+ default="",
+ type=str,
+ help="An optional path to a TensorFlow checkpoint path to be converted.",
+ )
+ parser.add_argument(
+ "--transfo_xl_config_file",
+ default="",
+ type=str,
+ help=(
+ "An optional config json file corresponding to the pre-trained BERT model. \n"
+ "This specifies the model architecture."
+ ),
+ )
+ parser.add_argument(
+ "--transfo_xl_dataset_file",
+ default="",
+ type=str,
+ help="An optional dataset file to be converted in a vocabulary.\n"
+ "Given the files are in the pickle format, please be wary of passing it files you trust.",
+ )
+ args = parser.parse_args()
+ convert_transfo_xl_checkpoint_to_pytorch(
+ args.tf_checkpoint_path,
+ args.transfo_xl_config_file,
+ args.pytorch_dump_folder_path,
+ args.transfo_xl_dataset_file,
+ )
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e2ebc3436db35a002120601ff60b66169318bb
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py
@@ -0,0 +1,1129 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TF 2.0 Transformer XL model.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ....modeling_tf_utils import (
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ....tf_utils import shape_list, stable_softmax
+from ....utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_transfo_xl import TransfoXLConfig
+from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "transfo-xl/transfo-xl-wt103"
+_CONFIG_FOR_DOC = "TransfoXLConfig"
+
+
class TFPositionalEmbedding(keras.layers.Layer):
    """Sinusoidal positional embedding layer (Transformer-XL style)."""

    def __init__(self, demb, **kwargs):
        super().__init__(**kwargs)

        # Inverse frequencies 10000^(-2i/demb) for the even embedding dimensions.
        self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))

    def call(self, pos_seq, bsz=None):
        # Keep the cached frequencies in the same dtype as the incoming positions.
        self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
        angles = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
        embedding = tf.concat([tf.sin(angles), tf.cos(angles)], -1)

        # Insert a batch axis; replicate across the batch only when a size is given.
        expanded = embedding[:, None, :]
        if bsz is None:
            return expanded
        return tf.tile(expanded, [1, bsz, 1])
+
+
class TFPositionwiseFF(keras.layers.Layer):
    """Position-wise feed-forward sub-layer with residual connection and pre-/post-layer norm."""

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, init_std=0.02, **kwargs):
        super().__init__(**kwargs)

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        self.pre_lnorm = pre_lnorm

        # Expansion / projection layers; names fixed for checkpoint compatibility.
        self.layer_1 = keras.layers.Dense(
            d_inner, kernel_initializer=get_initializer(init_std), activation=tf.nn.relu, name="CoreNet_._0"
        )
        self.drop_1 = keras.layers.Dropout(dropout)
        self.layer_2 = keras.layers.Dense(d_model, kernel_initializer=get_initializer(init_std), name="CoreNet_._3")
        self.drop_2 = keras.layers.Dropout(dropout)

        self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm")

    def call(self, inp, training=False):
        def _feed_forward(x):
            # dense -> dropout -> dense -> dropout, exactly the original CoreNet order.
            x = self.layer_1(x)
            x = self.drop_1(x, training=training)
            x = self.layer_2(x)
            return self.drop_2(x, training=training)

        if self.pre_lnorm:
            # Pre-LN: normalize the input, run the FF, then add the residual.
            return inp + _feed_forward(self.layer_norm(inp))
        # Post-LN: run the FF on the raw input, then normalize the residual sum.
        return self.layer_norm(inp + _feed_forward(inp))
+
+
class TFRelPartialLearnableMultiHeadAttn(keras.layers.Layer):
    """
    Multi-head attention with relative positional encodings and learned global
    biases (`r_w_bias` / `r_r_bias`), as used by Transformer-XL.

    A single dense `qkv_net` produces queries/keys/values; `r_net` projects the
    positional embeddings. Content scores (AC) and position scores (BD, after a
    relative shift) are summed before the softmax.
    """

    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        dropout,
        dropatt=0.0,
        pre_lnorm=False,
        r_r_bias=None,
        r_w_bias=None,
        layer_norm_epsilon=1e-5,
        init_std=0.02,
        output_attentions=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        self.output_attentions = output_attentions

        # One fused projection for queries, keys and values; split in call().
        self.qkv_net = keras.layers.Dense(
            3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net"
        )

        self.drop = keras.layers.Dropout(dropout)
        self.dropatt = keras.layers.Dropout(dropatt)
        # Output projection back to the model dimension.
        self.o_net = keras.layers.Dense(
            d_model, kernel_initializer=get_initializer(init_std), use_bias=False, name="o_net"
        )

        self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm")

        # 1/sqrt(d_head) attention scaling factor.
        self.scale = 1 / (d_head**0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is not None and r_w_bias is not None:  # Biases are shared
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias
        else:
            # Left as None here; per-layer biases are created in build().
            self.r_r_bias = None
            self.r_w_bias = None

        # Projection applied to the relative positional embeddings.
        self.r_net = keras.layers.Dense(
            self.n_head * self.d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="r_net"
        )

    def build(self, input_shape):
        if self.r_r_bias is None or self.r_w_bias is None:  # Biases are not shared
            self.r_r_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
            )
            self.r_w_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
            )
        super().build(input_shape)

    def _rel_shift(self, x):
        # Relative-shift trick: pad one column, reshape to swap the leading two
        # dims, drop the first row, and reshape back. Net effect: row i of the
        # position-score matrix is shifted left by i, aligning each query with
        # its relative distances.
        x_size = shape_list(x)

        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
        x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
        x = tf.reshape(x, x_size)

        return x

    def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False):
        # w: [qlen, bsz, d_model] current-segment hidden states; r: positional embeddings.
        qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]

        if mems is not None:
            # Prepend cached memory so keys/values cover [mlen + qlen] steps,
            # while queries are taken from the current segment only.
            mems = tf.cast(mems, dtype=w.dtype)
            cat = tf.concat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)

        klen = shape_list(w_head_k)[0]

        w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head))  # qlen x bsz x n_head x d_head
        w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head))  # qlen x bsz x n_head x d_head
        w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head))  # qlen x bsz x n_head x d_head

        r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head))  # qlen x n_head x d_head

        # compute attention score: AC is the content term, BD the position term.
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k)  # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = tf.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k)  # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score = attn_score * self.scale

        # compute attention probability; masked positions get a large negative
        # score so they vanish after the softmax.
        if attn_mask is not None:
            attn_mask_t = attn_mask[:, :, None, None]
            attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
            attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t

        # [qlen x klen x bsz x n_head]; softmax over the key axis (axis=1).
        attn_prob = stable_softmax(attn_score, axis=1)
        attn_prob = self.dropatt(attn_prob, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v)

        # [qlen x bsz x n_head x d_head] -> merge heads back into one axis.
        attn_vec_sizes = shape_list(attn_vec)
        attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head))

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out, training=training)

        if self.pre_lnorm:
            # residual connection
            outputs = [w + attn_out]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if output_attentions:
            outputs.append(attn_prob)

        return outputs
+
+
class TFRelPartialLearnableDecoderLayer(keras.layers.Layer):
    """One Transformer-XL decoder block: relative multi-head attention followed by a position-wise FF."""

    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        d_inner,
        dropout,
        dropatt=0.0,
        pre_lnorm=False,
        r_w_bias=None,
        r_r_bias=None,
        layer_norm_epsilon=1e-5,
        init_std=0.02,
        output_attentions=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Sub-layer 1: relative-position multi-head self-attention.
        self.dec_attn = TFRelPartialLearnableMultiHeadAttn(
            n_head,
            d_model,
            d_head,
            dropout,
            r_w_bias=r_w_bias,
            r_r_bias=r_r_bias,
            dropatt=dropatt,
            pre_lnorm=pre_lnorm,
            layer_norm_epsilon=layer_norm_epsilon,
            init_std=init_std,
            output_attentions=output_attentions,
            name="dec_attn",
        )
        # Sub-layer 2: two-layer position-wise feed-forward network.
        self.pos_ff = TFPositionwiseFF(
            d_model,
            d_inner,
            dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            pre_lnorm=pre_lnorm,
            init_std=init_std,
            name="pos_ff",
        )

    def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False):
        attn_out = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training)
        hidden = self.pos_ff(attn_out[0], training=training)

        # Keep any extra attention outputs (probabilities) alongside the hidden states.
        return [hidden, *attn_out[1:]]
+
+
class TFTransfoEmbeddings(keras.layers.Layer):
    """A plain trainable embedding lookup table with a configurable init std."""

    def __init__(self, vocab_size, emb_size, init_std, **kwargs):
        super().__init__(**kwargs)

        self.init_std = init_std
        self.emb_size = emb_size
        self.vocab_size = vocab_size

    def build(self, input_shape):
        # A single (vocab_size, emb_size) matrix; the weight name is fixed for
        # checkpoint compatibility.
        self.weight = self.add_weight(
            shape=(self.vocab_size, self.emb_size),
            initializer=get_initializer(self.init_std),
            name="embeddings",
        )

        super().build(input_shape)

    def call(self, inputs):
        # Row lookup: token ids -> embedding vectors.
        return tf.gather(self.weight, inputs)
+
+
class TFAdaptiveEmbedding(keras.layers.Layer):
    """
    Adaptive input embeddings: the vocabulary is split into frequency clusters
    (delimited by `cutoffs`), each cluster using a smaller embedding dimension
    (`d_embed // div_val**i`) followed by a projection up to `d_proj`.

    Note: `sample_softmax` is accepted for interface parity but never read here.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
        super().__init__(**kwargs)

        self.n_token = n_token
        self.d_embed = d_embed
        self.init_std = init_std

        # Append the vocab size so each cluster i covers [cutoff_ends[i], cutoff_ends[i+1]).
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        # Output embeddings are rescaled by sqrt(d_proj) in call().
        self.emb_scale = d_proj**0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = []
        self.emb_projs = []

        if div_val == 1:
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
        else:
            # One embedding table per cluster, each with a shrinking dimension.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(
                    TFTransfoEmbeddings(
                        r_idx - l_idx,
                        d_emb_i,
                        init_std,
                        name=f"emb_layers_._{i}",
                    )
                )

    def build(self, input_shape):
        # One (d_emb_i, d_proj) projection matrix per cluster.
        for i in range(len(self.cutoffs)):
            d_emb_i = self.d_embed // (self.div_val**i)
            self.emb_projs.append(
                self.add_weight(
                    shape=(d_emb_i, self.d_proj),
                    initializer=get_initializer(self.init_std),
                    trainable=True,
                    name=f"emb_projs_._{i}",
                )
            )

        super().build(input_shape)

    def call(self, inp):
        if self.div_val == 1:
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
        else:
            # Flatten ids, embed each cluster's tokens separately, and scatter
            # the projected vectors back into a single flat result.
            inp_flat = tf.reshape(inp, (-1,))
            emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj])
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                # Boolean mask selecting the ids belonging to cluster i.
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)

                # Re-base ids to the cluster-local range before the lookup.
                inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])

                # Scatter this cluster's rows back to their original positions.
                mask_idx = tf.where(mask_i)
                scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
                emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
                emb_flat += scatter

        embed_shape = shape_list(inp) + [self.d_proj]
        embed = tf.reshape(emb_flat, embed_shape)

        embed *= self.emb_scale

        return embed
+
+
@keras_serializable
class TFTransfoXLMainLayer(keras.layers.Layer):
    """
    Core Transformer-XL stack: adaptive input embeddings, a stack of decoder
    layers with relative-position attention, and segment-level memory (`mems`)
    carried across calls.
    """

    config_class = TransfoXLConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.return_dict = config.use_return_dict

        self.n_token = config.vocab_size

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head
        self.untie_r = config.untie_r

        # Adaptive (clustered) input embeddings.
        self.word_emb = TFAdaptiveEmbedding(
            config.vocab_size,
            config.d_embed,
            config.d_model,
            config.cutoffs,
            div_val=config.div_val,
            init_std=config.init_std,
            name="word_emb",
        )

        self.drop = keras.layers.Dropout(config.dropout)

        self.n_layer = config.n_layer
        self.mem_len = config.mem_len
        self.attn_type = config.attn_type

        self.layers = []
        if config.attn_type == 0:  # the default attention
            # NOTE(review): when untie_r is False the two bias reads below access
            # self.r_w_bias / self.r_r_bias before build() creates them — presumably
            # only untie_r=True configs are exercised with this port; verify.
            for i in range(config.n_layer):
                self.layers.append(
                    TFRelPartialLearnableDecoderLayer(
                        config.n_head,
                        config.d_model,
                        config.d_head,
                        config.d_inner,
                        config.dropout,
                        dropatt=config.dropatt,
                        pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if self.untie_r else self.r_w_bias,
                        r_r_bias=None if self.untie_r else self.r_r_bias,
                        layer_norm_epsilon=config.layer_norm_epsilon,
                        init_std=config.init_std,
                        output_attentions=self.output_attentions,
                        name=f"layers_._{i}",
                    )
                )
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = TFPositionalEmbedding(self.d_model, name="pos_emb")
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

    def build(self, input_shape):
        # Shared-bias variant: one pair of biases for the whole stack.
        if not self.untie_r:
            self.r_w_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
            )
            self.r_r_bias = self.add_weight(
                shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
            )
        super().build(input_shape)

    def get_input_embeddings(self):
        return self.word_emb

    def set_input_embeddings(self, value):
        raise NotImplementedError

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_memory_length(self, mem_len):
        # Change how many past hidden states are retained per layer.
        self.mem_len = mem_len

    def _prune_heads(self, heads):
        raise NotImplementedError

    def init_mems(self, bsz):
        """Create zero-initialized memory tensors (one per layer), or None if mem_len == 0."""
        if self.mem_len > 0:
            mems = []
            for i in range(self.n_layer):
                empty = tf.zeros([self.mem_len, bsz, self.d_model])
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, mlen, qlen):
        """Append this segment's hidden states to memory and keep the last `mem_len` steps."""
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), "len(hids) != len(mems)"

        # There are `mlen + qlen` steps that can be cached into mems
        new_mems = []
        end_idx = mlen + tf.math.maximum(0, qlen)
        beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
        for i in range(len(hids)):
            mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
            cat = tf.concat([mems[i], hids[i]], axis=0)
            # Memory is treated as a constant: no gradients flow into cached states.
            tf.stop_gradient(cat)
            new_mems.append(cat[beg_idx:end_idx])

        return new_mems

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        mems: List[tf.Tensor] | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: bool = False,
    ):
        # `labels` is accepted for interface uniformity but is not used by this base layer.
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = tf.transpose(input_ids, perm=(1, 0))
            qlen, bsz = shape_list(input_ids)
        elif inputs_embeds is not None:
            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
            qlen, bsz = shape_list(inputs_embeds)[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if mems is None:
            mems = self.init_mems(bsz)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layer

        if inputs_embeds is not None:
            word_emb = inputs_embeds
        else:
            word_emb = self.word_emb(input_ids)

        # Keys/values span the cached memory plus the current segment.
        mlen = shape_list(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        # Compute decoder attention mask
        # upper_mask hides keys more than `mlen` steps ahead of each query,
        # i.e. the causal mask over the concatenated [memory + segment] axis.
        all_ones = tf.ones([qlen, klen], dtype=tf.int32)
        upper_mask = 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, mlen)
        if self.same_length:
            # Additionally hide keys too far in the past so every query attends
            # over the same number of positions (config.same_length semantics).
            mask_len = klen - self.mem_len
            mask_shift_len = qlen - tf.nn.relu(mask_len)  # Lazy clamping of negatives to zero

            # Use an indicator variable instead of a conditional to keep the compiler happy
            lower_mask = tf.linalg.band_part(all_ones, -1, 0) - (
                tf.linalg.band_part(all_ones, mask_shift_len - 1, 0) * tf.cast(mask_shift_len != 0, tf.int32)
            )
            dec_attn_mask = upper_mask + lower_mask
        else:
            dec_attn_mask = upper_mask

        hids = []
        attentions = [] if output_attentions else None
        if self.attn_type == 0:  # default
            # Relative positions run from klen-1 down to 0; distances beyond
            # clamp_len all reuse the same positional embedding.
            pos_seq = tf.range(klen - 1, -1, -1.0)
            if self.clamp_len > 0:
                pos_seq = tf.minimum(pos_seq, self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb, training=training)
            pos_emb = self.drop(pos_emb, training=training)

            for i, layer in enumerate(self.layers):
                # Record the layer input: these become the new memory below.
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(
                    core_out,
                    pos_emb,
                    dec_attn_mask,
                    mems_i,
                    head_mask[i],
                    output_attentions,
                    training=training,
                )
                core_out = layer_outputs[0]
                if output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out, training=training)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        # We transpose back here to shape [bsz, len, hidden_dim]
        core_out = tf.transpose(core_out, perm=(1, 0, 2))

        if output_hidden_states:
            # Transpose to library standard shape [bsz, len, hidden_dim] and add last layer
            hids = tuple(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
            hids = hids + (core_out,)
        else:
            hids = None
        if output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

        if not return_dict:
            return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)

        return TFTransfoXLModelOutput(
            last_hidden_state=core_out,
            mems=new_mems,
            hidden_states=hids,
            attentions=attentions,
        )
+
+
+class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    # Config class instantiated by `from_pretrained` / `from_config` for this architecture.
+    config_class = TransfoXLConfig
+    # Attribute name of the headless base model; used to (un)prefix weight names on load/save.
+    base_model_prefix = "transformer"
+
+
+@dataclass
+class TFTransfoXLModelOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        mems (`List[tf.Tensor]` of length `config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
+            input) to speed up sequential decoding. The token ids which have their past given to this model should not
+            be passed as input ids as they have already been computed.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    # All fields default to None; annotations are marked optional accordingly, and the
+    # tuple fields are variadic (`Tuple[tf.Tensor, ...]`, one entry per layer), not 1-tuples.
+    last_hidden_state: tf.Tensor | None = None
+    mems: List[tf.Tensor] | None = None
+    hidden_states: Tuple[tf.Tensor, ...] | None = None
+    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFTransfoXLLMHeadModelOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        losses (`tf.Tensor` of shape *(batch_size, sequence_length-1)*, *optional*, returned when `labels` is provided):
+            Language modeling losses (not reduced).
+        prediction_scores (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax).
+        mems (`List[tf.Tensor]` of length `config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
+            input) to speed up sequential decoding. The token ids which have their past given to this model should not
+            be passed as input ids as they have already been computed.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    # All fields default to None; annotations are marked optional accordingly, and the
+    # tuple fields are variadic (`Tuple[tf.Tensor, ...]`, one entry per layer), not 1-tuples.
+    prediction_scores: tf.Tensor | None = None
+    mems: List[tf.Tensor] | None = None
+    hidden_states: Tuple[tf.Tensor, ...] | None = None
+    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFTransfoXLSequenceClassifierOutputWithPast(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        mems (`List[tf.Tensor]` of length `config.n_layers`):
+            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
+            input) to speed up sequential decoding. The token ids which have their past given to this model should not
+            be passed as input ids as they have already been computed.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    # All fields default to None; annotations are marked optional accordingly, and the
+    # tuple fields are variadic (`Tuple[tf.Tensor, ...]`, one entry per layer), not 1-tuples.
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    mems: List[tf.Tensor] | None = None
+    hidden_states: Tuple[tf.Tensor, ...] | None = None
+    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
+TRANSFO_XL_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+ behavior.
+
+
+
+ TensorFlow models and layers in `transformers` accept two formats as input:
+
+ - having all inputs as keyword arguments (like PyTorch models), or
+ - having all inputs as a list, tuple or dict in the first positional argument.
+
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+ positional argument:
+
+ - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+ `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+ `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+ Note that when creating models and layers with
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+ about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+ Parameters:
+ config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TRANSFO_XL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+ [`PreTrainedTokenizer.encode`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ mems (`List[tf.Tensor]` of length `config.n_layers`):
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
+ given to this model should not be passed as `input_ids` as they have already been computed.
+ head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+ config will be used instead.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+ used instead.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+ eager mode, in graph mode the value will always be set to True.
+ training (`bool`, *optional*, defaults to `False`):
+ Whether or not to use the model in training mode (some modules like dropout modules have different
+ behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ TRANSFO_XL_START_DOCSTRING,
+)
+class TFTransfoXLModel(TFTransfoXLPreTrainedModel):
+ def __init__(self, config, *inputs, **kwargs):
+ super().__init__(config, *inputs, **kwargs)
+ self.transformer = TFTransfoXLMainLayer(config, name="transformer")
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFTransfoXLModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ mems: List[tf.Tensor] | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ training: bool = False,
+ ) -> TFTransfoXLModelOutput | Tuple[tf.Tensor]:
+ outputs = self.transformer(
+ input_ids=input_ids,
+ mems=mems,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ training=training,
+ )
+
+ return outputs
+
+
+@add_start_docstrings(
+ """
+ The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive
+ input embeddings)
+ """,
+ TRANSFO_XL_START_DOCSTRING,
+)
+class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = TFTransfoXLMainLayer(config, name="transformer")
+ self.sample_softmax = config.sample_softmax
+ assert self.sample_softmax <= 0, (
+ "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
+ " https://github.com/huggingface/transformers/issues/3310"
+ )
+
+ self.crit = TFAdaptiveSoftmaxMask(
+ config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit"
+ )
+
+ def _resize_token_embeddings(self, new_num_tokens):
+ raise NotImplementedError()
+
+ def get_output_embeddings(self):
+ """Double-check if you are using adaptive softmax."""
+ if len(self.crit.out_layers) > 0:
+ return self.crit.out_layers[-1]
+ return None
+
+ def reset_memory_length(self, mem_len):
+ self.transformer.reset_memory_length(mem_len)
+
+ def init_mems(self, bsz):
+ return self.transformer.init_mems(bsz)
+
+ @unpack_inputs
+ @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TFTransfoXLLMHeadModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def call(
+ self,
+ input_ids: TFModelInputType | None = None,
+ mems: List[tf.Tensor] | None = None,
+ head_mask: np.ndarray | tf.Tensor | None = None,
+ inputs_embeds: np.ndarray | tf.Tensor | None = None,
+ output_attentions: bool | None = None,
+ output_hidden_states: bool | None = None,
+ return_dict: bool | None = None,
+ labels: np.ndarray | tf.Tensor | None = None,
+ training: bool = False,
+ ) -> TFTransfoXLLMHeadModelOutput | Tuple[tf.Tensor]:
+ if input_ids is not None:
+ bsz, tgt_len = shape_list(input_ids)[:2]
+ else:
+ bsz, tgt_len = shape_list(inputs_embeds)[:2]
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ mems,
+ head_mask,
+ inputs_embeds,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ training=training,
+ )
+
+ last_hidden = transformer_outputs[0]
+ pred_hid = last_hidden[:, -tgt_len:]
+
+ softmax_output = self.crit(pred_hid, labels, training=training)
+ prediction_scores = softmax_output if labels is None else ()
+
+ if not return_dict:
+ return (prediction_scores,) + transformer_outputs[1:]
+
+ return TFTransfoXLLMHeadModelOutput(
+ prediction_scores=prediction_scores,
+ mems=transformer_outputs.mems,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
+ inputs = {}
+
+ # if past is defined in model kwargs then use it for faster decoding
+ if past_key_values:
+ input_ids = tf.expand_dims(input_ids[:, -1], axis=-1)
+ else:
+ input_ids = input_ids
+
+ return inputs
+
+ # Adapted from the torch tie_weights function
+ def tf_to_pt_weight_rename(self, tf_weight):
+ if self.config.tie_word_embeddings and "crit.out_layers" in tf_weight:
+ return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers")
+ elif self.config.tie_projs and "crit.out_projs" in tf_weight:
+ for i, tie_proj in enumerate(self.config.tie_projs):
+ if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
+ # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
+ return tf_weight, tf_weight.replace(f"crit.out_projs.{i}", "transformer.word_emb.emb_projs.0")
+ elif tie_proj and self.config.div_val != 1:
+ # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
+ return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs")
+ else:
+ return (tf_weight,)
+
+
+@add_start_docstrings(
+    """
+    The Transfo XL Model transformer with a sequence classification head on top (linear layer).
+
+    [`TFTransfoXLForSequenceClassification`] uses the last token in order to do the classification, as other causal
+    models (e.g. GPT-1,GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    TRANSFO_XL_START_DOCSTRING,
+)
+class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        # Classification head: maps the selected token's hidden state to `num_labels` logits.
+        self.score = keras.layers.Dense(
+            config.num_labels,
+            kernel_initializer=get_initializer(config.init_range),
+            name="score",
+            use_bias=False,
+        )
+        self.transformer = TFTransfoXLMainLayer(config, name="transformer")
+
+    def get_output_embeddings(self):
+        # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too.
+        logger.warning(
+            "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed "
+            "in transformers v4.32."
+        )
+        return self.transformer.word_emb
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFTransfoXLSequenceClassifierOutputWithPast,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        mems: List[tf.Tensor] | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[Tuple, TFTransfoXLSequenceClassifierOutputWithPast]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
+            config.vocab_size - 1]`.
+        """
+        transformer_outputs = self.transformer(
+            input_ids=input_ids,
+            mems=mems,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        hidden_states = transformer_outputs[0]
+        # Per-position logits; the relevant position (last non-pad token) is selected below.
+        logits = self.score(hidden_states)
+        in_logits = None
+        if self.config.pad_token_id is None:
+            # No pad token configured: fall back to the last position of every row.
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                # Index of the first pad token minus one, i.e. the last non-pad position
+                # (assumes right padding). If a row contains no pad token, argmax returns 0,
+                # giving -1, which is replaced by the last index on the next line.
+                sequence_lengths = (
+                    tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
+                    - 1
+                )
+                sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
+            else:
+                sequence_lengths = -1
+                logger.warning_once(
+                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+                )
+        loss = None
+
+        if labels is not None:
+            if input_ids is not None:
+                batch_size, sequence_length = shape_list(input_ids)[:2]
+            else:
+                batch_size, sequence_length = shape_list(inputs_embeds)[:2]
+            assert self.config.pad_token_id is not None or batch_size == 1, (
+                "Cannot handle batch sizes > 1 if no padding token is defined."
+            )
+
+            # Here `sequence_lengths` is the plain Python int -1 (no gather happened above),
+            # so this selects the last position of each row.
+            if not tf.is_tensor(sequence_lengths):
+                in_logits = logits[0:batch_size, sequence_lengths]
+
+            loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
+
+        pooled_logits = in_logits if in_logits is not None else logits
+
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFTransfoXLSequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            mems=transformer_outputs.mems,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+# Public API of this module, consumed by `from ... import *` and the lazy-import machinery.
+__all__ = [
+    "TFAdaptiveEmbedding",
+    "TFTransfoXLForSequenceClassification",
+    "TFTransfoXLLMHeadModel",
+    "TFTransfoXLMainLayer",
+    "TFTransfoXLModel",
+    "TFTransfoXLPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..48205e06fb20a473959544db4971dff0d3e58cbf
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -0,0 +1,178 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A TF 2.0 Adaptive Softmax for Transformer XL model.
+"""
+
+import tensorflow as tf
+
+from ....modeling_tf_utils import keras
+from ....tf_utils import shape_list
+
+
+class TFAdaptiveSoftmaxMask(keras.layers.Layer):
+    """
+    TF 2.0 adaptive softmax output layer for Transformer-XL.
+
+    The vocabulary is split at `cutoffs` into a frequent "head" shortlist plus rarer "tail"
+    clusters; with `div_val > 1` each successive tail cluster uses a smaller embedding size.
+    When `target` is provided, `call` registers the LM loss on the layer via `add_loss`.
+    """
+
+    def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.d_embed = d_embed
+        self.d_proj = d_proj
+
+        # Append the full vocab size so every token id falls in exactly one [l_idx, r_idx) bucket.
+        self.cutoffs = cutoffs + [vocab_size]
+        self.cutoff_ends = [0] + self.cutoffs
+        self.div_val = div_val
+
+        self.shortlist_size = self.cutoffs[0]
+        self.n_clusters = len(self.cutoffs) - 1
+        # The head predicts shortlist tokens plus one "cluster token" per tail cluster.
+        self.head_size = self.shortlist_size + self.n_clusters
+        self.keep_order = keep_order
+
+        # Populated lazily in `build` (per-cluster (weight, bias) pairs and projections).
+        self.out_layers = []
+        self.out_projs = []
+
+    def build(self, input_shape):
+        # Weights are created here so the layer follows the standard Keras build cycle.
+        if self.n_clusters > 0:
+            # Extra rows/biases for the per-cluster "routing" tokens appended to the head.
+            self.cluster_weight = self.add_weight(
+                shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight"
+            )
+            self.cluster_bias = self.add_weight(
+                shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias"
+            )
+
+        if self.div_val == 1:
+            # Single embedding size: full-vocab weight/bias created per cutoff index; `call`
+            # slices rows [l_idx:r_idx] out of layer 0 as needed.
+            for i in range(len(self.cutoffs)):
+                if self.d_proj != self.d_embed:
+                    weight = self.add_weight(
+                        shape=(self.d_embed, self.d_proj),
+                        initializer="zeros",
+                        trainable=True,
+                        name=f"out_projs_._{i}",
+                    )
+                    self.out_projs.append(weight)
+                else:
+                    # No projection needed when the model and embedding dims already match.
+                    self.out_projs.append(None)
+                weight = self.add_weight(
+                    shape=(self.vocab_size, self.d_embed),
+                    initializer="zeros",
+                    trainable=True,
+                    name=f"out_layers_._{i}_._weight",
+                )
+                bias = self.add_weight(
+                    shape=(self.vocab_size,),
+                    initializer="zeros",
+                    trainable=True,
+                    name=f"out_layers_._{i}_._bias",
+                )
+                self.out_layers.append((weight, bias))
+        else:
+            # Shrinking embedding sizes: cluster i covers ids [l_idx, r_idx) with dim
+            # d_embed / div_val**i, each with its own projection back to d_proj.
+            for i in range(len(self.cutoffs)):
+                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+                d_emb_i = self.d_embed // (self.div_val**i)
+
+                weight = self.add_weight(
+                    shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name=f"out_projs_._{i}"
+                )
+                self.out_projs.append(weight)
+                weight = self.add_weight(
+                    shape=(r_idx - l_idx, d_emb_i),
+                    initializer="zeros",
+                    trainable=True,
+                    name=f"out_layers_._{i}_._weight",
+                )
+                bias = self.add_weight(
+                    shape=(r_idx - l_idx,),
+                    initializer="zeros",
+                    trainable=True,
+                    name=f"out_layers_._{i}_._bias",
+                )
+                self.out_layers.append((weight, bias))
+        super().build(input_shape)
+
+    @staticmethod
+    def _logit(x, W, b, proj=None):
+        # Optionally project x, then compute logits: (len, bsz, dim) x (n_tok, dim) -> (len, bsz, n_tok).
+        y = x
+        if proj is not None:
+            y = tf.einsum("ibd,ed->ibe", y, proj)
+        return tf.einsum("ibd,nd->ibn", y, W) + b
+
+    @staticmethod
+    def _gather_logprob(logprob, target):
+        # Pick logprob[j, target[j]] for every row j.
+        lp_size = shape_list(logprob)
+        r = tf.range(lp_size[0], dtype=target.dtype)
+        idx = tf.stack([r, target], 1)
+        return tf.gather_nd(logprob, idx)
+
+    def call(self, hidden, target, return_mean=True, training=False):
+        head_logprob = 0
+        if self.n_clusters == 0:
+            # No tail clusters: plain softmax over the whole vocabulary.
+            output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
+            if target is not None:
+                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
+            out = tf.nn.log_softmax(output, axis=-1)
+        else:
+            hidden_sizes = shape_list(hidden)
+            out = []
+            loss = tf.zeros(hidden_sizes[:2])
+            for i in range(len(self.cutoffs)):
+                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
+                if target is not None:
+                    # Positions whose target falls in this cluster; targets become cluster-local ids.
+                    mask = (target >= l_idx) & (target < r_idx)
+                    mask_idx = tf.where(mask)
+                    cur_target = tf.boolean_mask(target, mask) - l_idx
+
+                if self.div_val == 1:
+                    # Shared full-vocab parameters: slice this cluster's rows out of layer 0.
+                    cur_W = self.out_layers[0][0][l_idx:r_idx]
+                    cur_b = self.out_layers[0][1][l_idx:r_idx]
+                else:
+                    cur_W = self.out_layers[i][0]
+                    cur_b = self.out_layers[i][1]
+
+                if i == 0:
+                    # Head: shortlist tokens plus one routing token per tail cluster.
+                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
+                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)
+
+                    head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
+                    head_logprob = tf.nn.log_softmax(head_logit)
+                    out.append(head_logprob[..., : self.cutoffs[0]])
+                    if target is not None:
+                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
+                        cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
+                else:
+                    # Tail cluster i: logprob = head routing logprob + in-cluster logprob.
+                    tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
+                    tail_logprob = tf.nn.log_softmax(tail_logit)
+                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
+                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
+                    out.append(logprob_i)
+                    if target is not None:
+                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
+                        cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
+                        cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
+                        # cutoff_ends[1] + i - 1 == cutoffs[0] + i - 1, the routing-token index.
+                        cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
+                if target is not None:
+                    # Scatter this cluster's NLL back into the full (len, bsz) loss tensor.
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
+            out = tf.concat(out, axis=-1)
+
+        if target is not None:
+            if return_mean:
+                loss = tf.reduce_mean(loss)
+            # Add the training-time loss value to the layer using `self.add_loss()`.
+            self.add_loss(loss)
+
+            # Log the loss as a metric (we could log arbitrary metrics,
+            # including different metrics for training and inference.
+            self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "")
+
+        return out
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf843850c05c411cec507971aeb8557a3038591b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
@@ -0,0 +1,1303 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular
+https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
+"""
+
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....modeling_utils import PreTrainedModel
+from ....utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_transfo_xl import TransfoXLConfig
+from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "transfo-xl/transfo-xl-wt103"
+_CONFIG_FOR_DOC = "TransfoXLConfig"
+
+
def build_tf_to_pytorch_map(model, config):
    """
    A map of modules from TF to PyTorch. This time I use a map to keep the PyTorch model as identical to the original
    PyTorch model as possible.

    Keys are TF checkpoint variable names; values are the PyTorch parameters (or lists of
    per-layer parameters) the TF values should be copied into.
    """
    tf_to_pt_map = {}

    if hasattr(model, "transformer"):
        # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax
        tf_to_pt_map.update(
            {
                "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
                "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
            }
        )
        # One output layer (and optional projection) per adaptive-softmax cluster.
        for i, (out_l, proj_l, tie_proj) in enumerate(
            zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
        ):
            layer_str = f"transformer/adaptive_softmax/cutoff_{i}/"
            if config.tie_word_embeddings:
                # Output weights are tied to the input embeddings, so only the bias is mapped.
                tf_to_pt_map.update({layer_str + "b": out_l.bias})
            else:
                # Untied output embeddings are not supported by this loader.
                raise NotImplementedError
                # I don't think this is implemented in the TF code
                # NOTE(review): the update below is unreachable after the raise above — kept
                # as documentation of what loading untied weights would look like.
                tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias})
            if not tie_proj:
                tf_to_pt_map.update({layer_str + "proj": proj_l})
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = f"transformer/adaptive_embed/cutoff_{i}/"
        tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l})

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = f"transformer/layer_{i}/"
        tf_to_pt_map.update(
            {
                layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
                layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
                layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
                layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
                layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
                layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
                layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
                layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
                layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
                layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
                layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
            }
        )

    # Relative positioning biases
    if config.untie_r:
        # Per-layer biases: collect one entry per decoder layer.
        r_r_list = []
        r_w_list = []
        for b in model.layers:
            r_r_list.append(b.dec_attn.r_r_bias)
            r_w_list.append(b.dec_attn.r_w_bias)
    else:
        # Biases are shared across all layers.
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list})
    return tf_to_pt_map
+
+
def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """Load tf checkpoints in a pytorch model.

    Reads every variable from the TF checkpoint at `tf_path`, copies each one into the
    PyTorch parameter named by `build_tf_to_pytorch_map`, and returns the mutated `model`.
    Raises ImportError if TensorFlow is not installed, AssertionError on name/shape mismatches.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights
        array = tf_weights[name]
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if "kernel" in name or "proj" in name:
            # TF stores dense kernels transposed relative to torch.nn.Linear weights.
            array = np.transpose(array)
        if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1:
            # Here we will split the TF weights
            # (untied per-layer biases are stacked along dim 0 in a single TF variable).
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                logger.info(f"Initialize PyTorch weight {name} for layer {i}")
                p_i.data = torch.from_numpy(arr_i)
        else:
            try:
                assert pointer.shape == array.shape, (
                    f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
                )
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            logger.info(f"Initialize PyTorch weight {name}")
            pointer.data = torch.from_numpy(array)
        # Drop the consumed variable (and its Adam optimizer slots, if present) so that
        # anything left in tf_weights can be reported as "not copied" below.
        tf_weights.pop(name, None)
        tf_weights.pop(name + "/Adam", None)
        tf_weights.pop(name + "/Adam_1", None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}")
    return model
+
+
class PositionalEmbedding(nn.Module):
    """Sinusoidal positional embedding used for Transformer-XL's relative positions.

    Produces, for a 1-D tensor of positions, an embedding of size `demb` whose first
    half is the sines and second half the cosines of the scaled positions.
    """

    def __init__(self, demb):
        super().__init__()
        self.demb = demb
        # Geometric frequency schedule, identical to the original Transformer:
        # 1 / 10000^(2k / demb) for k = 0 .. demb/2 - 1.
        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer("inv_freq", 1 / (10000**exponents))

    def forward(self, pos_seq, bsz=None):
        # Outer product via broadcasting: one row of phases per position.
        phases = pos_seq[:, None] * self.inv_freq[None, :]
        # Shape [len, 1, demb]; the singleton dim is the batch axis.
        emb = torch.cat((phases.sin(), phases.cos()), dim=-1)[:, None, :]
        if bsz is None:
            return emb
        # Broadcast (no copy) across the requested batch size.
        return emb.expand(-1, bsz, -1)
+
+
class PositionwiseFF(nn.Module):
    """Position-wise two-layer feed-forward block with a residual connection.

    Supports both post-norm (default) and pre-norm (`pre_lnorm=True`) layouts.
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout
        # NOTE: keep this exact Sequential layout (Linear modules at indices 0 and 3) —
        # the TF checkpoint loader addresses CoreNet[0] and CoreNet[3] directly.
        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # Pre-norm: normalize first, feed forward, then add the residual.
            return inp + self.CoreNet(self.layer_norm(inp))
        # Post-norm: feed forward, add the residual, then normalize.
        return self.layer_norm(inp + self.CoreNet(inp))
+
+
class RelPartialLearnableMultiHeadAttn(nn.Module):
    """Multi-head self-attention with Transformer-XL relative positional encoding.

    Combines a content-based score (queries biased by `r_w_bias` against keys) with a
    position-based score (queries biased by `r_r_bias` against projected relative
    position embeddings), optionally prepending cached memory states to keys/values.
    All tensors use the time-first layout [len, bsz, ...].
    """

    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        dropout,
        dropatt=0,
        pre_lnorm=False,
        r_r_bias=None,
        r_w_bias=None,
        layer_norm_epsilon=1e-5,
    ):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        # Single fused projection producing queries, keys and values in one matmul.
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        # Standard 1/sqrt(d_head) attention scaling.
        self.scale = 1 / (d_head**0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            # NOTE(review): FloatTensor leaves these uninitialized; they rely on the
            # model's _init_weights pass to get real values.
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

        # Projection applied to the relative positional embeddings.
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def _rel_shift(self, x):
        # Relative-shift trick: prepend a zero column along dim 1, view the result with
        # the first two dims swapped, then drop the first row. This realigns entry
        # (i, j) so the position scores index the correct relative offset.
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        return x

    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False):
        # w: hidden states [qlen, bsz, d_model]; r: relative position embeddings.
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            # Prepend cached memories so keys/values cover mlen + qlen steps;
            # queries come from the current segment only (last qlen rows).
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # qlen x n_head x d_head

        # compute attention score
        # AC: content-based term (query + content bias) . key
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

        # BD: position-based term (query + position bias) . relative position key
        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k))  # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # Most negative representable value for the current dtype (fp16-safe masking).
        mask_value = torch.finfo(attn_score.dtype).min

        # compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = attn_mask == 1  # Switch to bool
            # Masking is done in float32 then cast back, to avoid fp16 overflow issues.
            if attn_mask.dim() == 2:
                attn_score = (
                    attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score)
                )
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score)

        # [qlen x klen x bsz x n_head] — softmax over the key dimension (dim=1).
        attn_prob = nn.functional.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] -> merge heads back to d_model-sized vectors
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            outputs = [w + attn_out]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if output_attentions:
            outputs.append(attn_prob)

        return outputs
+
+
class RelPartialLearnableDecoderLayer(nn.Module):
    """One Transformer-XL decoder block: relative-position self-attention followed by a
    position-wise feed-forward network."""

    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs):
        super().__init__()
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs
        )
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon
        )

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None, output_attentions=False):
        # Self-attention first; element 0 of the result is the hidden states, any
        # remaining elements are attention probabilities (when requested).
        attn_out = self.dec_attn(
            dec_inp,
            r,
            attn_mask=dec_attn_mask,
            mems=mems,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        # Run the attended states through the FFN; pass the extras through untouched.
        return [self.pos_ff(attn_out[0]), *attn_out[1:]]
+
+
class AdaptiveEmbedding(nn.Module):
    """Adaptive input embedding: with `div_val == 1` a single full-vocabulary table is
    used; otherwise each cutoff bucket gets its own (progressively narrower) table that
    is projected up to `d_proj`. The output is scaled by sqrt(d_proj)."""

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj
        self.emb_scale = d_proj**0.5
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            # One table for the whole vocabulary; add a projection only if dims differ.
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            # Bucket k covers token ids [cutoff_ends[k], cutoff_ends[k+1]) with an
            # embedding width shrunk by div_val**k.
            for k in range(len(self.cutoffs)):
                lo, hi = self.cutoff_ends[k], self.cutoff_ends[k + 1]
                width = d_embed // (div_val**k)
                self.emb_layers.append(nn.Embedding(hi - lo, width))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, width)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            ref = next(self.parameters())
            flat = inp.view(-1)
            gathered = torch.zeros([flat.size(0), self.d_proj], dtype=ref.dtype, device=ref.device)
            for k in range(len(self.cutoffs)):
                lo, hi = self.cutoff_ends[k], self.cutoff_ends[k + 1]
                # Positions whose token id falls inside this bucket.
                bucket = ((flat >= lo) & (flat < hi)).nonzero().squeeze()
                if bucket.numel() == 0:
                    continue
                rows = self.emb_layers[k](flat.index_select(0, bucket) - lo)
                gathered.index_copy_(0, bucket, nn.functional.linear(rows, self.emb_projs[k]))
            embed = gathered.view(inp.size() + (self.d_proj,))

        # Scale in place by sqrt(d_proj), as in the original Transformer-XL.
        embed.mul_(self.emb_scale)

        return embed
+
+
class TransfoXLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TransfoXLConfig
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        """Initialize a weight tensor according to the scheme selected in the config."""
        if self.config.init == "uniform":
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == "normal":
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        """Initialize a bias tensor to zero."""
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """Initialize the weights."""
        # Dispatch on the module class name rather than isinstance, so that the
        # custom modules of this file (AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax)
        # are matched by substring before their base classes.
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            if hasattr(m, "weight") and m.weight is not None:
                self._init_weight(m.weight)
            if hasattr(m, "bias") and m.bias is not None:
                self._init_bias(m.bias)
        elif classname.find("AdaptiveEmbedding") != -1:
            if hasattr(m, "emb_projs"):
                for i in range(len(m.emb_projs)):
                    if m.emb_projs[i] is not None:
                        # Projections use their own (usually smaller) std.
                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find("Embedding") != -1:
            if hasattr(m, "weight"):
                self._init_weight(m.weight)
        elif classname.find("ProjectedAdaptiveLogSoftmax") != -1:
            if hasattr(m, "cluster_weight") and m.cluster_weight is not None:
                self._init_weight(m.cluster_weight)
            if hasattr(m, "cluster_bias") and m.cluster_bias is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, "out_projs"):
                for i in range(len(m.out_projs)):
                    if m.out_projs[i] is not None:
                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find("LayerNorm") != -1:
            if hasattr(m, "weight"):
                # LayerNorm scale is initialized around 1, not 0.
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, "bias") and m.bias is not None:
                self._init_bias(m.bias)
        else:
            # Fallback for modules carrying relative-position parameters directly.
            if hasattr(m, "r_emb"):
                self._init_weight(m.r_emb)
            if hasattr(m, "r_w_bias"):
                self._init_weight(m.r_w_bias)
            if hasattr(m, "r_r_bias"):
                self._init_weight(m.r_r_bias)
            if hasattr(m, "r_bias"):
                self._init_bias(m.r_bias)

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, layer: Optional[int] = -1):
        """
        Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. Take care of tying
        weights embeddings afterwards if the model class has a *tie_weights()* method.

        Arguments:
            new_num_tokens: (*optional*) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at
                the end. Reducing the size will remove vectors from the end. If not provided or None: does nothing and
                just returns a pointer to the input tokens `torch.nn.Embeddings` Module of the model.
            layer: (*optional*) int:
                Layer of the *AdaptiveEmbedding* where the resizing should be done. Per default the last layer will be
                resized. Be aware that when resizing other than the last layer, you have to ensure that the new
                token(s) in the tokenizer are at the corresponding position.

        Return: `torch.nn.Embeddings` Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        if new_num_tokens is None:
            return self.get_input_embeddings()

        new_num_tokens_layer, layer = self._get_new_num_tokens_layer(new_num_tokens, layer)
        assert new_num_tokens_layer > 0, "The size of the new embedding layer cannot be 0 or less"
        model_embeds = base_model._resize_token_embeddings(new_num_tokens_layer, layer)

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        base_model.n_token = new_num_tokens

        # Recompute the adaptive-softmax cutoffs from the new per-layer table sizes.
        new_embedding_shapes = self._get_embedding_shapes()
        self._resize_cutoffs(new_num_tokens, new_num_tokens_layer, new_embedding_shapes, layer)

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds

    def _get_new_num_tokens_layer(self, new_num_tokens, layer):
        """Return the new size of embedding layer `layer` (total tokens minus the sizes of all other layers)."""
        embeddings = self.get_input_embeddings()
        if layer == -1:
            layer = len(embeddings.emb_layers) - 1
        assert 0 <= layer <= len(embeddings.emb_layers) - 1

        new_num_tokens_layer = (
            new_num_tokens
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[:layer]])
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[layer + 1 :]])
        )
        return new_num_tokens_layer, layer

    def _get_embedding_shapes(self):
        """Return the vocabulary size of each adaptive embedding layer."""
        embeddings = self.get_input_embeddings()
        return [emb.weight.shape[0] for emb in embeddings.emb_layers]

    def _resize_token_embeddings(self, new_num_tokens, layer=-1):
        """Resize one layer of the adaptive embedding in place and return the embedding module."""
        embeddings = self.get_input_embeddings()
        if new_num_tokens is None:
            return embeddings
        new_embeddings_layer = self._get_resized_embeddings(embeddings.emb_layers[layer], new_num_tokens)
        embeddings.emb_layers[layer] = new_embeddings_layer

        self.set_input_embeddings(embeddings)

        return self.get_input_embeddings()

    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
        """Recompute cutoffs from `layer` onwards so they match the new embedding layer sizes."""
        embeddings = self.get_input_embeddings()

        # Each cutoff is the cumulative size of the layers up to and including that index.
        for i in range(layer, len(embeddings.cutoffs)):
            embeddings.cutoffs[i] = sum(new_embedding_shapes[: i + 1])

        embeddings.cutoff_ends = [0] + embeddings.cutoffs
        embeddings.n_token = new_num_tokens

        # config.cutoffs excludes the final implicit cutoff (== vocab size).
        self.config.cutoffs = embeddings.cutoffs[:-1]

        return embeddings.cutoffs
+
+
@dataclass
class TransfoXLModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        mems (`List[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
            input) to speed up sequential decoding. The token ids which have their past given to this model should not
            be passed as input ids as they have already been computed.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: torch.FloatTensor
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
@dataclass
class TransfoXLSequenceClassifierOutputWithPast(ModelOutput):
    """
    Base class for outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        mems (`List[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
            input) to speed up sequential decoding. The token ids which have their past given to this model should not
            be passed as input ids as they have already been computed.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
@dataclass
class TransfoXLLMHeadModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        losses (`torch.FloatTensor` of shape *(batch_size, sequence_length-1)*, *optional*, returned when `labels` is provided):
            Language modeling losses (not reduced).
        prediction_scores (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax).
        mems (`List[torch.FloatTensor]` of length `config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems`
            input) to speed up sequential decoding. The token ids which have their past given to this model should not
            be passed as input ids as they have already been computed.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `labels` is provided):
            Reduced language modeling loss.
    """

    losses: Optional[torch.FloatTensor] = None
    prediction_scores: Optional[torch.FloatTensor] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    loss: Optional[torch.FloatTensor] = None

    @property
    def logits(self):
        # prediction scores are the output of the adaptive softmax, see
        # the file `modeling_transfo_xl_utilities`. Since the adaptive
        # softmax returns the log softmax value, `self.prediction_scores`
        # are strictly speaking not exactly `logits`, but behave the same
        # way logits do.
        return self.prediction_scores
+
+
+TRANSFO_XL_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+TRANSFO_XL_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ mems (`List[torch.FloatTensor]` of length `config.n_layers`):
+ Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
+ `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
+ given to this model should not be passed as `input_ids` as they have already been computed.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
+ TRANSFO_XL_START_DOCSTRING,
+)
+class TransfoXLModel(TransfoXLPreTrainedModel):
    def __init__(self, config):
        """Build the Transformer-XL backbone (embeddings, decoder layers, positional embedding) from `config`."""
        super().__init__(config)

        self.n_token = config.vocab_size

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        # Adaptive input embeddings: larger tables for more frequent tokens.
        self.word_emb = AdaptiveEmbedding(
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
        )

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer
        self.mem_len = config.mem_len
        self.attn_type = config.attn_type

        if not config.untie_r:
            # Relative-position attention biases shared across all layers.
            # NOTE(review): FloatTensor leaves these uninitialized until _init_weights runs.
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type == 0:  # the default attention
            for i in range(config.n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        config.n_head,
                        config.d_model,
                        config.d_head,
                        config.d_inner,
                        config.dropout,
                        dropatt=config.dropatt,
                        pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        layer_norm_epsilon=config.layer_norm_epsilon,
                    )
                )
        else:  # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
            raise NotImplementedError  # Removed them to avoid maintaining dead code

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        # Initialize weights and apply final processing
        self.post_init()
+
    def get_input_embeddings(self):
        # The adaptive embedding module serves as the input embedding table.
        return self.word_emb
+
    def set_input_embeddings(self, new_embeddings):
        # Replace the whole adaptive input embedding module.
        self.word_emb = new_embeddings
+
    def backward_compatible(self):
        # Kept for backward compatibility with older checkpoints/configs;
        # -1 presumably marks sampled softmax as disabled — TODO confirm against callers.
        self.sample_softmax = -1
+
    def reset_memory_length(self, mem_len):
        """Set the number of hidden-state steps cached as memory for subsequent segments."""
        self.mem_len = mem_len
+
    def _prune_heads(self, heads):
        # Deliberate no-op: head pruning is not supported for this architecture.
        logger.info("Head pruning is not implemented for Transformer-XL model")
        pass
+
+ def init_mems(self, bsz):
+ if self.mem_len > 0:
+ mems = []
+ param = next(self.parameters())
+ for i in range(self.n_layer):
+ empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device)
+ mems.append(empty)
+
+ return mems
+ else:
+ return None
+
+ def _update_mems(self, hids, mems, mlen, qlen):
+ # does not deal with None
+ if mems is None:
+ return None
+
+ # mems is not None
+ assert len(hids) == len(mems), "len(hids) != len(mems)"
+
+ # There are `mlen + qlen` steps that can be cached into mems
+ with torch.no_grad():
+ new_mems = []
+ end_idx = mlen + max(0, qlen)
+ beg_idx = max(0, end_idx - self.mem_len)
+ for i in range(len(hids)):
+ cat = torch.cat([mems[i], hids[i]], dim=0)
+ new_mems.append(cat[beg_idx:end_idx].detach())
+
+ return new_mems
+
+ @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TransfoXLModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ mems: Optional[List[torch.FloatTensor]] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, TransfoXLModelOutput]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
+ # so we transpose here from shape [bsz, len] to shape [len, bsz]
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_ids = input_ids.transpose(0, 1).contiguous()
+ qlen, bsz = input_ids.size()
+ elif inputs_embeds is not None:
+ inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+ qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if mems is None:
+ mems = self.init_mems(bsz)
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
+ # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
+ if head_mask is not None:
+ if head_mask.dim() == 1:
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+ head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
+ elif head_mask.dim() == 2:
+ head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+ head_mask = head_mask.to(
+ dtype=next(self.parameters()).dtype
+ ) # switch to float if need + fp16 compatibility
+ else:
+ head_mask = [None] * self.n_layer
+
+ if inputs_embeds is not None:
+ word_emb = inputs_embeds
+ else:
+ word_emb = self.word_emb(input_ids)
+
+ mlen = mems[0].size(0) if mems is not None else 0
+ klen = mlen + qlen
+ if self.same_length:
+ all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool)
+ mask_len = klen - self.mem_len
+ if mask_len > 0:
+ mask_shift_len = qlen - mask_len
+ else:
+ mask_shift_len = qlen
+ dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
+ else:
+ dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[
+ :, :, None
+ ]
+
+ hids = []
+ attentions = [] if output_attentions else None
+ if self.attn_type == 0: # default
+ pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=torch.int64).type_as(
+ dtype=word_emb.dtype
+ )
+ if self.clamp_len > 0:
+ pos_seq.clamp_(max=self.clamp_len)
+ pos_emb = self.pos_emb(pos_seq)
+
+ core_out = self.drop(word_emb)
+ pos_emb = self.drop(pos_emb)
+
+ for i, layer in enumerate(self.layers):
+ hids.append(core_out)
+ mems_i = None if mems is None else mems[i]
+ layer_outputs = layer(
+ core_out,
+ pos_emb,
+ dec_attn_mask=dec_attn_mask,
+ mems=mems_i,
+ head_mask=head_mask[i],
+ output_attentions=output_attentions,
+ )
+ core_out = layer_outputs[0]
+ if output_attentions:
+ attentions.append(layer_outputs[1])
+ else: # learnable embeddings and absolute embeddings
+ raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
+
+ core_out = self.drop(core_out)
+
+ new_mems = self._update_mems(hids, mems, mlen, qlen)
+
+ if output_hidden_states:
+ # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
+ hids.append(core_out)
+ hids = tuple(t.transpose(0, 1).contiguous() for t in hids)
+ else:
+ hids = None
+ if output_attentions:
+ # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
+ attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+ # We transpose back here to shape [bsz, len, hidden_dim]
+ core_out = core_out.transpose(0, 1).contiguous()
+
+ if not return_dict:
+ return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)
+
+ return TransfoXLModelOutput(
+ last_hidden_state=core_out,
+ mems=new_mems,
+ hidden_states=hids,
+ attentions=attentions,
+ )
+
+
@add_start_docstrings(
    """
    The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive
    input embeddings)
    """,
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    # Parameter-name patterns whose weights are tied to the input embeddings (see `tie_weights`).
    _tied_weights_keys = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax
        # Legacy switch controlling the ordering of tuple outputs in `forward` (see below).
        self.trainer_compatible = getattr(config, "trainer_compatible", False)

        if not self.trainer_compatible:
            warnings.warn(
                "The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order "
                "to use that updated output, please specify `trainer_compatible=True` as your configuration"
                " attribute.",
                DeprecationWarning,
            )

        # Sampled softmax was never implemented; only the adaptive softmax path is supported.
        assert self.sample_softmax <= 0, (
            "Sampling from the softmax is not implemented yet. Please look at issue: #3310:"
            " https://github.com/huggingface/transformers/issues/3310"
        )

        # Adaptive softmax LM head (weights optionally tied to the adaptive input embeddings).
        self.crit = ProjectedAdaptiveLogSoftmax(
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
        )

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """

        if self.config.tie_word_embeddings:
            # Tie each adaptive-softmax output cluster to the matching input embedding cluster.
            for i in range(len(self.crit.out_layers)):
                self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
        if self.config.tie_projs:
            for i, tie_proj in enumerate(self.config.tie_projs):
                if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                    if self.config.torchscript:
                        # TorchScript cannot share parameter objects, so tie by cloning instead.
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
                    else:
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                elif tie_proj and self.config.div_val != 1:
                    if self.config.torchscript:
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
                    else:
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

    def reset_memory_length(self, mem_len):
        # Delegate to the base model, which owns the memory cache.
        self.transformer.reset_memory_length(mem_len)

    def init_mems(self, bsz):
        # Delegate to the base model, which owns the memory cache.
        return self.transformer.init_mems(bsz)

    @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TransfoXLLMHeadModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        mems: Optional[List[torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TransfoXLLMHeadModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None:
            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
        elif inputs_embeds is not None:
            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        transformer_outputs = self.transformer(
            input_ids,
            mems=mems,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden = transformer_outputs[0]
        # Only the hidden states for the current segment feed the LM head.
        pred_hid = last_hidden[:, -tgt_len:]

        if labels is not None:
            # Prevents all labels being -100 and throwing an error
            # when backwarding the loss
            miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100
            if miss_valid_label:
                # Set one valid token so the loss does not become NaN.
                # NOTE(review): this writes into the caller's `labels` tensor in place —
                # confirm callers do not rely on `labels` being left untouched.
                labels[0, 1] = self.config.eos_token_id

        # With labels, `crit` returns per-token NLL; without labels, log-probabilities.
        softmax_output = self.crit(pred_hid, labels)
        prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()

        if labels is not None:
            # Targets are shifted by one inside `crit`, hence `tgt_len - 1` positions.
            losses = softmax_output.view(bsz, tgt_len - 1)
            # Avoids from incorporating padding (-100) tokens into loss value
            loss = losses[losses != 0].mean()
        else:
            losses, loss = None, None

        if not return_dict:
            if self.trainer_compatible:
                # Trainer-style ordering: (loss, prediction_scores, losses, *extras).
                output = (prediction_scores, losses) if losses is not None else (prediction_scores,)
                output += transformer_outputs[1:]
                return ((loss,) + output) if loss is not None else output
            else:
                # Legacy ordering: (losses, prediction_scores, *extras, loss).
                output = (prediction_scores, *transformer_outputs[1:])
                output = ((losses,) + output) if losses is not None else output
                return (output + (loss,)) if loss is not None else output

        return TransfoXLLMHeadModelOutput(
            loss=loss,
            prediction_scores=prediction_scores,
            losses=losses,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_output_embeddings(self):
        """Double-check if you are using adaptive softmax."""
        if self.sample_softmax > 0:
            # Unreachable in practice: `__init__` asserts `sample_softmax <= 0`.
            return self.out_layer
        else:
            # The last adaptive-softmax cluster layer acts as the output embedding.
            return self.crit.out_layers[-1]

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs):
        inputs = {}

        # if past is defined in model kwargs then use it for faster decoding
        if past_key_values:
            inputs["mems"] = past_key_values
            # With cached memories, only the last token needs to be fed to the model.
            inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
        else:
            inputs["input_ids"] = input_ids

        return inputs

    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
        # Keep the adaptive softmax bookkeeping in sync after the embeddings were resized.
        new_cutoffs = super()._resize_cutoffs(new_num_tokens, new_emb_size, new_embedding_shapes, layer)

        self.crit.cutoffs = new_cutoffs
        self.crit.cutoff_ends = [0] + new_cutoffs
        self.crit.n_token = new_num_tokens

    @staticmethod
    def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
        """
        This function is used to re-order the `mems` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `mems` with the correct beam_idx at every
        generation step.
        """
        # `mems` tensors are laid out as [mem_len, batch, hidden], so beams are selected on dim 1.
        return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
+
+
@add_start_docstrings(
    """
    The Transformer-XL Model transformer with a sequence classification head on top (linear layer).

    [`TransfoXLForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = TransfoXLModel(config)
        # Classification head: maps the pooled hidden state to `num_labels` logits (no bias).
        self.score = nn.Linear(config.d_embed, self.num_labels, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TransfoXLSequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        mems: Optional[List[torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TransfoXLSequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            mems=mems,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        # Fix: validate with an explicit exception instead of `assert`, which is silently
        # stripped when Python runs with the -O flag.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # Pool the logit of the last non-padding token of each sequence.
        pooled_logits = logits[range(batch_size), sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once from label count/dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TransfoXLSequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
+
+
# Explicit public API of this module.
__all__ = [
    "AdaptiveEmbedding",
    "TransfoXLForSequenceClassification",
    "TransfoXLLMHeadModel",
    "TransfoXLModel",
    "TransfoXLPreTrainedModel",
    "load_tf_weights_in_transfo_xl",
]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..f76f3ccc6259fcb033b44eb43dd98be23482221c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py
@@ -0,0 +1,251 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utilities for PyTorch Transformer XL model. Directly adapted from https://github.com/kimiyoung/transformer-xl.
+"""
+
+import torch
+from torch import nn
+
+
+# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
+# CUDA_MINOR = int(torch.version.cuda.split('.')[1])
+
+
class ProjectedAdaptiveLogSoftmax(nn.Module):
    """
    Adaptive softmax with optional projections, as used by Transformer-XL.

    The vocabulary is split into a frequent "head" shortlist plus rarer "tail" clusters
    (boundaries given by `cutoffs`). With `div_val != 1` each tail cluster uses a smaller
    embedding size (`d_embed // div_val**i`) and its own projection back to `d_proj`.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, keep_order=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        # Cluster boundaries, e.g. cutoffs=[20000] -> cutoff_ends [0, 20000, n_token].
        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        # The head predicts shortlist tokens plus one "enter cluster i" pseudo-token per tail cluster.
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            # Single shared output layer over the whole vocabulary; per-cluster slices are views.
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            # One output layer (and projection) per cluster, with shrinking embedding sizes.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val**i)

                self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

                self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        # Optionally project `hidden` down/up to the cluster's embedding size before the linear map.
        if proj is None:
            logit = nn.functional.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = nn.functional.linear(hidden, proj.t().contiguous())
            logit = nn.functional.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, labels=None, keep_order=False):
        """
        Params:
            hidden :: [len*bsz x d_proj]
            labels :: [len*bsz]

        Return:
            if labels is None: out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary else: out ::
            [(len-1)*bsz] Negative log likelihood. We could replace this implementation by the native PyTorch one if
            theirs had an option to set bias on all clusters in the native one. here:
            https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
        """

        if labels is not None:
            # Shift so that tokens < n predict n
            hidden = hidden[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            hidden = hidden.view(-1, hidden.size(-1))
            labels = labels.view(-1)
            if hidden.size(0) != labels.size(0):
                raise RuntimeError("Input and labels should have the same size in the batch dimension.")
        else:
            hidden = hidden.view(-1, hidden.size(-1))

        if self.n_clusters == 0:
            # Plain (non-adaptive) softmax over the full vocabulary.
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            if labels is not None:
                mask = labels != -100
                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)
                out[mask] = (
                    -nn.functional.log_softmax(logit, dim=-1)[mask].gather(1, labels[mask].unsqueeze(1)).squeeze(1)
                )
            else:
                out = nn.functional.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    # The head additionally predicts one pseudo-token per tail cluster.
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = nn.functional.log_softmax(head_logit, dim=1)

            if labels is None:
                out = hidden.new_empty((head_logit.size(0), self.n_token))
            else:
                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                if labels is not None:
                    mask_i = (labels >= l_idx) & (labels < r_idx)
                    # Fix: squeeze only the trailing dim so a single match stays 1-D —
                    # a bare `.squeeze()` produced a 0-d tensor that broke `index_select`.
                    indices_i = mask_i.nonzero().squeeze(-1)

                    if indices_i.numel() == 0:
                        continue

                    target_i = labels.index_select(0, indices_i) - l_idx
                    head_logprob_i = head_logprob.index_select(0, indices_i)
                    hidden_i = hidden.index_select(0, indices_i)
                else:
                    hidden_i = hidden

                if i == 0:
                    if labels is not None:
                        logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                    else:
                        out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1)
                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    if labels is not None:
                        logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather(
                            1, target_i[:, None]
                        ).squeeze(1)
                    else:
                        logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                        out[:, l_idx:r_idx] = logprob_i

                if labels is not None:
                    if (hasattr(self, "keep_order") and self.keep_order) or keep_order:
                        out.index_copy_(0, indices_i, -logprob_i)
                    else:
                        out[offset : offset + logprob_i.size(0)].copy_(-logprob_i)
                    offset += logprob_i.size(0)

        return out

    def log_prob(self, hidden):
        r"""
        Computes log probabilities for all \\(n\_classes\\) From:
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.p

        Args:
            hidden (Tensor): a minibatch of example

        Returns:
            log-probabilities of for each class \\(c\\) in range \\(0 <= c <= n\_classes\\), where \\(n\_classes\\) is
            a parameter passed to `AdaptiveLogSoftmaxWithLoss` constructor. Shape:

            - Input: \\((N, in\_features)\\)
            - Output: \\((N, n\_classes)\\)
        """
        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            return nn.functional.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)

            out = hidden.new_empty((head_logit.size(0), self.n_token))
            head_logprob = nn.functional.log_softmax(head_logit, dim=1)

            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]

                if i == 0:
                    out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
                    tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1)

                    # Fix: the log-prob of entering cluster i lives at head position
                    # `cutoffs[0] + i - 1` (matching `forward`), not `-i`, and it must
                    # broadcast over the tail; the assignment was also indexing with two
                    # integers (`out[:, start_idx, stop_idx]`) instead of a slice, which
                    # raises an IndexError on the 2-D output tensor.
                    cluster_prob_idx = self.cutoffs[0] + i - 1
                    logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                    out[:, start_idx:stop_idx] = logprob_i

            return out
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a5fb7b3466b0091717dfe2a264c16b0a1aeeeb
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py
@@ -0,0 +1,821 @@
+# coding=utf-8
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tokenization classes for Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl.
+"""
+
+import glob
+import os
+import pickle
+import re
+from collections import Counter, OrderedDict
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+from ....tokenization_utils import PreTrainedTokenizer
+from ....utils import (
+ cached_file,
+ is_sacremoses_available,
+ is_torch_available,
+ logging,
+ requires_backends,
+ strtobool,
+ torch_only_method,
+)
+
+
+if is_sacremoses_available():
+ import sacremoses as sm
+
+
+if is_torch_available():
+ import torch
+
+
+logger = logging.get_logger(__name__)
+
# File names under which the tokenizer vocabulary is saved/looked up.
VOCAB_FILES_NAMES = {
    "pretrained_vocab_file": "vocab.pkl",
    "pretrained_vocab_file_torch": "vocab.bin",
    "vocab_file": "vocab.txt",
}


# Remote location of the pretrained corpus shipped with the original wt103 checkpoint.
PRETRAINED_CORPUS_ARCHIVE_MAP = {
    "transfo-xl/transfo-xl-wt103": "https://huggingface.co/transfo-xl/transfo-xl-wt103/resolve/main/corpus.bin",
}
# Local file name used when caching the corpus.
CORPUS_NAME = "corpus.bin"
+
# Regex + replacement used by `tokenize_numbers`: insert " @,@ " / " @.@ " around a
# comma or dot that sits between two digits.
MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ "
# Inverse substitutions used by `detokenize_numbers` to undo the transformation above.
DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")]


def tokenize_numbers(text_array: List[str]) -> List[str]:
    """
    Splits large comma-separated numbers and floating point values. This is done by replacing commas with ' @,@ ' and
    dots with ' @.@ '.

    Args:
        text_array: An already tokenized text as list.

    Returns:
        A list of strings with tokenized numbers.

    Example:

    ```python
    >>> tokenize_numbers(["$", "5,000", "1.73", "m"])
    ['$', '5', '@,@', '000', '1', '@.@', '73', 'm']
    ```"""
    # The pattern/replacement pair is loop-invariant: unpack it once instead of per token.
    reg, sub = MATCH_NUMBERS
    tokenized = []
    for token in text_array:
        # A substitution may split one token into several; flatten them into the output.
        tokenized.extend(re.sub(reg, sub, token).split())

    return tokenized
+
+
def detokenize_numbers(text: str) -> str:
    """
    Inverts the operation of *tokenize_numbers*. This is replacing ' @,@ ' and ' @.@' by ',' and '.'.

    Args:
        text: A string where the number should be detokenized.

    Returns:
        A detokenized string.

    Example:

    ```python
    >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m")
    '$ 5,000 1.73 m'
    ```"""
    detokenized = text
    # Apply each inverse substitution in turn.
    for pattern, replacement in DETOKENIZE_NUMBERS:
        detokenized = re.sub(pattern, replacement, detokenized)
    return detokenized
+
+
+class TransfoXLTokenizer(PreTrainedTokenizer):
+ """
+ Construct a Transformer-XL tokenizer adapted from Vocab class in [the original
+ code](https://github.com/kimiyoung/transformer-xl). The Transformer-XL tokenizer is a word-level tokenizer (no
+ sub-word tokenization).
+
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ special (`List[str]`, *optional*):
+ A list of special tokens (to be treated by the original implementation of this tokenizer).
+ min_freq (`int`, *optional*, defaults to 0):
+ The minimum number of times a token has to be present in order to be kept in the vocabulary (otherwise it
+ will be mapped to `unk_token`).
+ max_size (`int`, *optional*):
+ The maximum size of the vocabulary. If left unset, it will default to the size of the vocabulary found
+ after excluding the tokens according to the `min_freq` rule.
+ lower_case (`bool`, *optional*, defaults to `False`):
+ Whether or not to lowercase the input when tokenizing.
+ delimiter (`str`, *optional*):
+ The delimiter used between tokens.
+ vocab_file (`str`, *optional*):
+ File containing the vocabulary (from the original implementation).
+ pretrained_vocab_file (`str`, *optional*):
+ File containing the vocabulary as saved with the `save_pretrained()` method.
+ never_split (`List[str]`, *optional*):
+ List of tokens that should never be split. If no list is specified, will simply use the existing special
+ tokens.
+ unk_token (`str`, *optional*, defaults to `""`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ eos_token (`str`, *optional*, defaults to `""`):
+ The end of sequence token.
+ additional_special_tokens (`List[str]`, *optional*, defaults to `['']`):
+ A list of additional special tokens (for the HuggingFace functionality).
+ language (`str`, *optional*, defaults to `"en"`):
+ The language of this tokenizer (used for mose preprocessing).
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids"]
+
+ def __init__(
+ self,
+ special=None,
+ min_freq=0,
+ max_size=None,
+ lower_case=False,
+ delimiter=None,
+ vocab_file=None,
+ pretrained_vocab_file: Optional[str] = None,
+ never_split=None,
+ unk_token="",
+ eos_token="",
+ additional_special_tokens=[""],
+ language="en",
+ **kwargs,
+ ):
+ logger.error(
+ "`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. "
+ "See more details on this model's documentation page: "
+ "`https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`."
+ )
+
+ requires_backends(self, "sacremoses")
+ if special is None:
+ special = []
+ self.counter = Counter()
+ self.special = special
+ self.min_freq = min_freq
+ self.max_size = max_size
+ self.lower_case = lower_case
+ self.delimiter = delimiter
+ self.vocab_file = vocab_file
+ self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
+ self.punction_without_space_before_pattern = re.compile(rf"[^\s][{self.punctuation_symbols}]")
+ self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
+ self.language = language
+ self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
+ self.moses_tokenizer = sm.MosesTokenizer(language)
+ self.moses_detokenizer = sm.MosesDetokenizer(language)
+ self.idx2sym = []
+ self.sym2idx = OrderedDict()
+ # This try... catch... is not beautiful but honestly this tokenizer was not made to be used
+ # in a library like ours, at all.
+ try:
+ vocab_dict = None
+ if pretrained_vocab_file is not None:
+ # Priority on pickle files (support PyTorch and TF)
+ if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
+ raise ValueError(
+ "This part uses `pickle.load` which is insecure and will execute arbitrary code that is "
+ "potentially malicious. It's recommended to never unpickle data that could have come from an "
+ "untrusted source, or that could have been tampered with. If you already verified the pickle "
+ "data and decided to use it, you can set the environment variable "
+ "`TRUST_REMOTE_CODE` to `True` to allow it."
+ )
+ with open(pretrained_vocab_file, "rb") as f:
+ vocab_dict = pickle.load(f)
+
+ # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer
+ # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed.
+ # We therefore load it with torch, if it's available.
+ if isinstance(vocab_dict, int):
+ if not is_torch_available():
+ raise ImportError(
+ "Not trying to load dict with PyTorch as you need to install pytorch to load "
+ "from a PyTorch pretrained vocabulary, "
+ "or activate it with environment variables USE_TORCH=1 and USE_TF=0."
+ )
+ vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)
+
+ if vocab_dict is not None:
+ for key, value in vocab_dict.items():
+ if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]:
+ self.__dict__[key] = value
+ elif vocab_file is not None:
+ self.build_vocab()
+
+ except Exception as e:
+ raise ValueError(
+ f"Unable to parse file {pretrained_vocab_file}. Unknown format. "
+ "If you tried to load a model saved through TransfoXLTokenizerFast, "
+ "please note they are not compatible."
+ ) from e
+
+ if vocab_file is not None:
+ self.build_vocab()
+
+ super().__init__(
+ special=special,
+ min_freq=min_freq,
+ max_size=max_size,
+ lower_case=lower_case,
+ delimiter=delimiter,
+ vocab_file=vocab_file,
+ pretrained_vocab_file=pretrained_vocab_file,
+ never_split=never_split,
+ unk_token=unk_token,
+ eos_token=eos_token,
+ additional_special_tokens=additional_special_tokens,
+ language=language,
+ **kwargs,
+ )
+
+ # these are not required to initialize the parent class as only used when tokenizing.
+ if never_split is None:
+ never_split = self.all_special_tokens
+ self.never_split = never_split
+
    @property
    def do_lower_case(self):
        # Mirrors the `lower_case` constructor flag under the attribute name
        # the shared tokenizer API expects.
        return self.lower_case
+
+ def _compile_space_around_punctuation_pattern(self):
+ look_ahead_for_special_token = f"(?=[{self.punctuation_symbols}])"
+ look_ahead_to_match_all_except_space = r"(?=[^\s])"
+ return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)
+
+ def count_file(self, path, verbose=False, add_eos=False):
+ if verbose:
+ logger.info(f"counting file {path} ...")
+ assert os.path.exists(path), f"Input file {path} not found"
+
+ sents = []
+ with open(path, "r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ if verbose and idx > 0 and idx % 500000 == 0:
+ logger.info(f" line {idx}")
+ symbols = self.tokenize(line, add_eos=add_eos)
+ self.counter.update(symbols)
+ sents.append(symbols)
+
+ return sents
+
+ def count_sents(self, sents, verbose=False):
+ """
+ sents : a list of sentences, each a list of tokenized symbols
+ """
+ if verbose:
+ logger.info(f"counting {len(sents)} sents ...")
+ for idx, symbols in enumerate(sents):
+ if verbose and idx > 0 and idx % 500000 == 0:
+ logger.info(f" line {idx}")
+ self.counter.update(symbols)
+
+ def _build_from_file(self, vocab_file):
+ self.idx2sym = []
+ self.sym2idx = OrderedDict()
+
+ with open(vocab_file, "r", encoding="utf-8") as f:
+ for line in f:
+ symb = line.strip().split()[0]
+ self.add_symbol(symb)
+ if "" in self.sym2idx:
+ self.unk_idx = self.sym2idx[""]
+ elif "" in self.sym2idx:
+ self.unk_idx = self.sym2idx[""]
+ else:
+ raise ValueError("Token not in vocabulary and no token in vocabulary for replacement.")
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if os.path.isdir(save_directory):
+ vocab_file = os.path.join(
+ save_directory,
+ (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["pretrained_vocab_file"],
+ )
+ else:
+ vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+ with open(vocab_file, "wb") as f:
+ pickle.dump(self.__dict__, f)
+ return (vocab_file,)
+
+ def build_vocab(self):
+ if self.vocab_file:
+ logger.info(f"building vocab from {self.vocab_file}")
+ self._build_from_file(self.vocab_file)
+ logger.info(f"Final vocab size {len(self.sym2idx)}")
+ else:
+ logger.info(f"building vocab with min_freq={self.min_freq}, max_size={self.max_size}")
+ self.idx2sym = []
+ self.sym2idx = OrderedDict()
+
+ for sym in self.special:
+ self.add_special(sym)
+
+ for sym, cnt in self.counter.most_common(self.max_size):
+ if cnt < self.min_freq:
+ break
+ self.add_symbol(sym)
+
+ logger.info(f"Final vocab size {len(self.sym2idx)} from {len(self.counter)} unique tokens")
+
+ @torch_only_method
+ def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
+ if verbose:
+ logger.info(f"encoding file {path} ...")
+ assert os.path.exists(path), f"Output file {path} not found"
+ encoded = []
+ with open(path, "r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ if verbose and idx > 0 and idx % 500000 == 0:
+ logger.info(f" line {idx}")
+ symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
+ encoded.append(self.convert_to_tensor(symbols))
+
+ if ordered:
+ encoded = torch.cat(encoded)
+
+ return encoded
+
+ @torch_only_method
+ def encode_sents(self, sents, ordered=False, verbose=False):
+ if verbose:
+ logger.info(f"encoding {len(sents)} sents ...")
+ encoded = []
+ for idx, symbols in enumerate(sents):
+ if verbose and idx > 0 and idx % 500000 == 0:
+ logger.info(f" line {idx}")
+ encoded.append(self.convert_to_tensor(symbols))
+
+ if ordered:
+ encoded = torch.cat(encoded)
+
+ return encoded
+
+ def add_special(self, sym):
+ if sym not in self.sym2idx:
+ self.idx2sym.append(sym)
+ self.sym2idx[sym] = len(self.idx2sym) - 1
+ setattr(self, f"{sym.strip('<>')}_idx", self.sym2idx[sym])
+
+ def add_symbol(self, sym):
+ if sym not in self.sym2idx:
+ self.idx2sym.append(sym)
+ self.sym2idx[sym] = len(self.idx2sym) - 1
+
+ def move_added_token(self, token: str, target_idx: int):
+ """
+ Moves an added token to a specific position in the vocab. This method should be used when resizing an embedding
+ layer other than the last one in the `AdaptiveEmbedding` in order to move the token in the tokenizer from the
+ default position (at the very end) to the desired one.
+
+ Args:
+ token: The token to move to a specific position in the vocab.
+ target_idx: The position where the token should be moved to.
+ """
+ assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
+ assert token not in self.idx2sym, "Token which should be moved is already in vocab"
+
+ # Insert sym into vocab
+ self.idx2sym.insert(target_idx, token)
+ self.sym2idx[token] = target_idx
+
+ # Shift following indices in sym2idx
+ for idx in range(target_idx + 1, len(self.idx2sym)):
+ current_sym = self.idx2sym[idx]
+ self.sym2idx[current_sym] = idx
+
+ # Delete token from added_tokens
+ old_index = self._added_tokens_encoder.pop(token)
+ self._added_tokens_decoder.pop(old_index)
+
    def moses_punct_norm(self, text):
        # Normalize punctuation (quotes, dashes, spacing) via sacremoses
        # before tokenization.
        return self.moses_punct_normalizer.normalize(text)
+
    def moses_tokenize(self, text):
        # Moses word tokenization; tokens in `never_split` are protected from
        # splitting and no HTML escaping is applied.
        return self.moses_tokenizer.tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False, protected_patterns=self.never_split
        )
+
+ def moses_pipeline(self, text: str) -> List[str]:
+ """
+ Does basic tokenization using [`sacremoses.MosesPunctNormalizer`] and [`sacremoses.MosesTokenizer`] with
+ *aggressive_dash_splits=True* (see [`sacremoses.tokenize.MosesTokenizer.tokenize`]). Additionally, large
+ comma-separated numbers and floating point values are split. E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000
+ people are 1 @.@ 80m tall"
+
+ Args:
+ text: Text to be tokenize
+
+ Returns:
+ A list of tokenized string
+
+ Example:
+
+ ```python
+ >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl/transfo-xl-wt103")
+ >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
+ ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']
+ ```"""
+ text = self.moses_punct_norm(text)
+ text = self.moses_tokenize(text)
+ text = tokenize_numbers(text)
+ return text
+
    def _convert_id_to_token(self, idx):
        """Converts an id in a token (BPE) using the vocab."""
        # Out-of-range ids are a programming error here, not an input error.
        assert 0 <= idx < len(self), f"Index {idx} out of vocabulary range"
        return self.idx2sym[idx]
+
+ def _convert_token_to_id(self, sym):
+ """Converts a token (str) in an id using the vocab."""
+ if sym in self.sym2idx:
+ return self.sym2idx[sym]
+ else:
+ # logger.info(f'encounter unk {sym}')
+ # assert '' not in sym
+ if hasattr(self, "unk_idx"):
+ return self.sym2idx.get(sym, self.unk_idx)
+ # Backward compatibility with pre-trained models
+ elif "" in self.sym2idx:
+ return self.sym2idx[""]
+ elif "" in self.sym2idx:
+ return self.sym2idx[""]
+ else:
+ raise ValueError("Token not in vocabulary and no token in vocabulary for replacement.")
+
    def convert_tokens_to_string(self, tokens):
        """
        Converts a sequence of tokens (string) in a single string. Additionally, the split numbers (the `@,@` / `@.@`
        markers) are converted back into their original form.
        """
        out_string = self.moses_detokenizer.detokenize(tokens)
        return detokenize_numbers(out_string).strip()
+
    @torch_only_method
    def convert_to_tensor(self, symbols):
        # Map symbols to ids and wrap them in a 1-D LongTensor.
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))
+
    @property
    def vocab_size(self):
        # Size of the core vocabulary (excludes added tokens).
        return len(self.idx2sym)
+
+ def get_vocab(self):
+ vocab = self.sym2idx.copy()
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, line, add_eos=False, add_double_eos=False):
+ line = line.strip()
+ # convert to lower case
+ if self.lower_case:
+ line = line.lower()
+
+ # empty delimiter '' will evaluate False
+ if self.delimiter == "":
+ symbols = line
+ else:
+ symbols = self.moses_pipeline(line)
+
+ if add_double_eos: # lm1b
+ return [""] + symbols + [""]
+ elif add_eos:
+ return symbols + [""]
+ else:
+ return symbols
+
+
class LMOrderedIterator:
    """Iterate over one strictly ordered token stream in (bsz, bptt) batches."""

    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
        """
        data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len
        self.device = device

        # Number of whole rows of the (n_step, bsz) layout the stream yields.
        self.n_step = data.size(0) // bsz

        # Drop the trailing remainder, then lay the stream out so that column j
        # holds the j-th contiguous chunk of the data.
        usable = data.narrow(0, 0, self.n_step * bsz)
        self.data = usable.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches (ceiling division).
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        """Return (input, target, seq_len) for the batch starting at row `i`;
        inputs may include `ext_len` rows of extra left context."""
        bptt = self.bptt if bptt is None else bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        lo = max(0, i - self.ext_len)
        hi = i + seq_len
        inputs = self.data[lo:hi]
        # Targets are the inputs shifted one position to the right.
        labels = self.data[i + 1 : i + 1 + seq_len]

        inputs_out = inputs.transpose(0, 1).contiguous().to(self.device)
        labels_out = labels.transpose(0, 1).contiguous().to(self.device)
        return inputs_out, labels_out, seq_len

    def get_fixlen_iter(self, start=0):
        """Yield fixed-length batches covering the whole stream."""
        yield from (self.get_batch(i) for i in range(start, self.data.size(0) - 1, self.bptt))

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        """Yield batches whose length is sampled around `bptt` (clamped to
        [min_len, bptt + max_deviation * std])."""
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            base = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
            bptt = min(max_len, max(min_len, int(np.random.normal(base, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()
+
+
class LMShuffledIterator:
    """Batch iterator over a bag of independent sentence tensors, refilling
    each of the `bsz` streams from `sent_stream` as sentences run out."""

    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        """
        data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        """Yield the sentences of `self.data`, in shuffled order if requested."""
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    @torch_only_method
    def stream_iterator(self, sent_stream):
        """Yield (data, target, bptt) batches until `sent_stream` is exhausted.

        Each of the `bsz` streams is filled with consecutive tokens from one
        sentence at a time; the generator returns (truncating the last partial
        batch) as soon as any stream cannot be refilled.
        """
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        # Number of rows carried over from the previous batch as extra context.
        n_retain = 0

        while True:
            # data : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            # -1 marks positions that were never filled.
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        # Keep at least one token so a target exists for the
                        # last filled input position.
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
                        target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    # No sentences left: abandon this (partial) batch.
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data_out = data.transpose(0, 1).contiguous().to(self.device)
            target_out = target.transpose(0, 1).contiguous().to(self.device)

            yield data_out, target_out, self.bptt

            # Carry the tail of the batch forward as context for the next one.
            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch
+
+
class LMMultiFileIterator(LMShuffledIterator):
    """LMShuffledIterator variant that streams sentences from several files,
    encoding each file lazily through `vocab`."""

    def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        self.paths = paths
        self.vocab = vocab
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len
        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        """Encode one file into sentence tensors and return an iterator over
        them (shuffled in place when `self.shuffle` is set)."""
        sentences = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sentences)
        return iter(sentences)

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for file_path in self.paths:
            # Each file contributes its own sentence stream.
            yield from self.stream_iterator(self.get_sent_stream(file_path))
+
+
class TransfoXLCorpus:
    """
    Bundles a `TransfoXLTokenizer` (`self.vocab`) with the encoded train /
    valid / test splits of a language-modeling dataset.

    Fix: `get_iterator` previously left `data_iter` unbound for an
    unrecognized dataset name (raising UnboundLocalError) and contained a dead
    `data_iter = None` assignment before the split error; both branches now
    raise a clear `ValueError`.
    """

    @classmethod
    @torch_only_method
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
        """
        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
                f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
                f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
            )
            return None
        if is_local:
            logger.info(f"loading corpus file {resolved_corpus_file}")
        else:
            logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")

        # Instantiate tokenizer.
        corpus = cls(*inputs, **kwargs)
        corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        # Re-materialize splits as LongTensors (the cache may store lists).
        if corpus.train is not None:
            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
        if corpus.valid is not None:
            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
        if corpus.test is not None:
            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
        return corpus

    def __init__(self, *args, **kwargs):
        # Splits start empty; fill them via `build_corpus` or `from_pretrained`.
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None

    def build_corpus(self, path, dataset):
        """Count/tokenize the raw text files of `dataset` under `path`, build
        the vocabulary, and encode the train/valid/test splits."""
        self.dataset = dataset

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path,
                "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled",
                "news.en-*",
            )
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        elif self.dataset == "lm1b":
            # lm1b training data stays on disk and is streamed file by file.
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        """Return a batch iterator for `split` ("train", "valid" or "test").

        Raises:
            ValueError: if `split` or `self.dataset` is not recognized.
        """
        if split == "train":
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == "lm1b":
                kwargs["shuffle"] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
            else:
                # Previously fell through and raised UnboundLocalError.
                raise ValueError(f"Dataset not recognized: {self.dataset}")
        elif split in ["valid", "test"]:
            data = self.valid if split == "valid" else self.test
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == "lm1b":
                data_iter = LMShuffledIterator(data, *args, **kwargs)
            else:
                raise ValueError(f"Dataset not recognized: {self.dataset}")
        else:
            raise ValueError(f"Split not recognized: {split}")

        return data_iter
+
+
@torch_only_method
def get_lm_corpus(datadir, dataset):
    """
    Load a cached `TransfoXLCorpus` from `datadir`, or build one from the raw
    dataset files and cache it as `cache.pt`.

    Fixes: the torch branch loaded `fn_pickle` although it had checked `fn`,
    and the second branch re-checked `fn` (making the pickle path unreachable)
    while opening `fn` instead of `fn_pickle` — the two cache files were
    swapped throughout. The stripped special tokens (`<eos>`, `<unk>`) used
    when building a fresh corpus are also restored.
    """
    fn = os.path.join(datadir, "cache.pt")
    fn_pickle = os.path.join(datadir, "cache.pkl")
    if os.path.exists(fn):
        logger.info("Loading cached dataset...")
        corpus = torch.load(fn, weights_only=True)
    elif os.path.exists(fn_pickle):
        logger.info("Loading cached dataset from pickle...")
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        logger.info(f"Producing dataset {dataset}...")
        kwargs = {}
        if dataset in ["wt103", "wt2"]:
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = False
        elif dataset == "ptb":
            kwargs["special"] = ["<unk>"]
            kwargs["lower_case"] = True
        elif dataset == "lm1b":
            kwargs["special"] = []
            kwargs["lower_case"] = False
            kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
        elif dataset in ["enwik8", "text8"]:
            # Character-level datasets need no special tokens.
            pass

        corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus
+
+
+__all__ = ["TransfoXLCorpus", "TransfoXLTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..941db2f6ac5fefe66e06d598e1adbe0db1cf1769
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py
@@ -0,0 +1,20 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # Static type checkers see the real submodules directly.
    from .configuration_tvlt import *
    from .feature_extraction_tvlt import *
    from .processing_tvlt import *
    from .modeling_tvlt import *
    from .image_processing_tvlt import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy that only imports a
    # submodule on first attribute access (keeps `import transformers` fast).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf159fa7e0b7754778ebda57b46dea4479fdc3c9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py
@@ -0,0 +1,187 @@
+# coding=utf-8
+# Copyright 2023 MURGe-Lab and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TVLT model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class TvltConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`TvltModel`]. It is used to instantiate a TVLT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the TVLT
    [ZinengTang/tvlt-base](https://huggingface.co/ZinengTang/tvlt-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        spectrogram_length (`int`, *optional*, defaults to 2048):
            The time length of each audio spectrogram.
        frequency_length (`int`, *optional*, defaults to 128):
            The frequency length of audio spectrogram.
        image_patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
            The size (resolution) of each image patch.
        audio_patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
            The size (resolution) of each audio patch.
        num_image_channels (`int`, *optional*, defaults to 3):
            The number of input image channels.
        num_audio_channels (`int`, *optional*, defaults to 1):
            The number of input audio channels.
        num_frames (`int`, *optional*, defaults to 8):
            The maximum number of frames for an input video.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        use_mean_pooling (`bool`, *optional*, defaults to `False`):
            Whether to mean pool the final hidden states instead of using the final hidden state of the [CLS] token.
        decoder_num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        decoder_hidden_size (`int`, *optional*, defaults to 512):
            Dimensionality of the decoder.
        decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of hidden layers in the decoder.
        decoder_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
        pixel_mask_ratio (`float`, *optional*, defaults to 0.75):
            Image patch masking ratio.
        audio_mask_ratio (`float`, *optional*, defaults to 0.15):
            Audio patch masking ratio.
        audio_mask_type (`str`, *optional*, defaults to `"frame-level"`):
            Audio patch masking type, choose between "frame-level" and "patch-level".
        task_matching (`bool`, *optional*, defaults to `True`):
            Whether to use vision audio matching task in pretraining.
        task_mae (`bool`, *optional*, defaults to `True`):
            Whether to use the masked auto-encoder (MAE) in pretraining.
        loss_type (`str`, *optional*, defaults to `"classification"`):
            Loss types including regression and classification.

    Example:

    ```python
    >>> from transformers import TvltConfig, TvltModel

    >>> # Initializing a TVLT ZinengTang/tvlt-base style configuration
    >>> configuration = TvltConfig()

    >>> # Initializing a model (with random weights) from the ZinengTang/tvlt-base style configuration
    >>> model = TvltModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "tvlt"

    def __init__(
        self,
        image_size=224,
        spectrogram_length=2048,
        frequency_length=128,
        image_patch_size=None,
        audio_patch_size=None,
        num_image_channels=3,
        num_audio_channels=1,
        num_frames=8,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        qkv_bias=True,
        use_mean_pooling=False,
        decoder_num_attention_heads=16,
        decoder_hidden_size=512,
        decoder_num_hidden_layers=8,
        decoder_intermediate_size=2048,
        pixel_mask_ratio=0.75,
        audio_mask_ratio=0.15,
        audio_mask_type="frame-level",
        task_matching=True,
        task_mae=True,
        loss_type="classification",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Avoid shared mutable defaults: materialize per-instance lists.
        if image_patch_size is None:
            image_patch_size = [16, 16]
        if audio_patch_size is None:
            audio_patch_size = [16, 16]

        # Fix: the check previously accepted only the undocumented "patch_level"
        # spelling and its error message was garbled. Accept the documented
        # "patch-level" plus the legacy "patch_level" for backward compatibility.
        if audio_mask_type not in ("frame-level", "patch-level", "patch_level"):
            raise ValueError(
                f'audio_mask_type must be one of "frame-level" or "patch-level", got {audio_mask_type}'
            )

        self.image_size = image_size
        self.spectrogram_length = spectrogram_length
        self.frequency_length = frequency_length
        self.image_patch_size = image_patch_size
        self.audio_patch_size = audio_patch_size
        self.num_image_channels = num_image_channels
        self.num_audio_channels = num_audio_channels
        self.num_frames = num_frames

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.qkv_bias = qkv_bias
        self.use_mean_pooling = use_mean_pooling

        self.decoder_num_attention_heads = decoder_num_attention_heads
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_num_hidden_layers = decoder_num_hidden_layers
        self.decoder_intermediate_size = decoder_intermediate_size
        self.pixel_mask_ratio = pixel_mask_ratio
        self.audio_mask_ratio = audio_mask_ratio
        self.audio_mask_type = audio_mask_type

        self.task_matching = task_matching
        self.task_mae = task_mae
        self.loss_type = loss_type
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbbfac9031b9be499bd168c74681f32e32357c5b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py
@@ -0,0 +1,233 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for TVLT."""
+
+from math import ceil
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ....audio_utils import mel_filter_bank, spectrogram, window_function
+from ....feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor
+from ....utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class TvltFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs a TVLT audio feature extractor. This feature extractor can be used to prepare audios for the model.
+
+ This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
+ should refer to this superclass for more information regarding those methods.
+
+ Args:
+ spectrogram_length (`Dict[str, int]` *optional*, defaults to 2048):
+ The time length of each audio spectrogram.
+ num_channels (`int` *optional*, defaults to 1):
+ Number of audio channels.
+ patch_size (`List[int]` *optional*, defaults to `[16, 16]`):
+ The patch size of audio patch embedding.
+ feature_size (`int`, *optional*, defaults to 128):
+ The frequency length of audio spectrogram.
+ sampling_rate (`int`, *optional*, defaults to 44100):
+ The sampling rate at which the audio files should be digitalized expressed in Hertz (Hz).
+ hop_length_to_sampling_rate (`int`, *optional*, defaults to 86):
+ Hop length is length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
+ For example, with sampling rate 44100, the hop length is 512, with 44100 / 512 = 86
+ n_fft (`int`, *optional*, defaults to 2048):
+ Size of the Fourier transform.
+ padding_value (`float`, *optional*, defaults to 0.0):
+ Padding value used to pad the audio. Should correspond to silences.
+ """
+
+ model_input_names = ["audio_values", "audio_mask"]
+
+ def __init__(
+ self,
+ spectrogram_length=2048,
+ num_channels=1,
+ patch_size=[16, 16],
+ feature_size=128,
+ sampling_rate=44100,
+ hop_length_to_sampling_rate=86,
+ n_fft=2048,
+ padding_value=0.0,
+ **kwargs,
+ ):
+ super().__init__(
+ feature_size=feature_size,
+ sampling_rate=sampling_rate,
+ padding_value=padding_value,
+ **kwargs,
+ )
+
+ self.spectrogram_length = spectrogram_length
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.freq_len = feature_size // self.patch_size[1]
+ self.n_fft = n_fft
+ self.hop_length = sampling_rate // hop_length_to_sampling_rate
+ self.sampling_rate = sampling_rate
+ self.padding_value = padding_value
+ self.mel_filters = mel_filter_bank(
+ num_frequency_bins=1 + n_fft // 2,
+ num_mel_filters=feature_size,
+ min_frequency=0.0,
+ max_frequency=22050.0,
+ sampling_rate=sampling_rate,
+ norm="slaney",
+ mel_scale="slaney",
+ ).T
+
+ def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
+ """
+ Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
+ implementation with 1e-5 tolerance.
+ """
+ log_spec = spectrogram(
+ waveform,
+ window_function(self.n_fft, "hann"),
+ frame_length=self.n_fft,
+ hop_length=self.hop_length,
+ power=2.0,
+ mel_filters=self.mel_filters.T,
+ log_mel="dB",
+ db_range=80.0,
+ )
+ log_spec = log_spec[:, :-1]
+ log_spec = log_spec - 20.0
+ log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0
+ return log_spec
+
+ def __call__(
+ self,
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = True,
+ sampling_rate: Optional[int] = None,
+ resample: bool = False,
+ mask_audio: bool = False,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Main method to prepare one or several audio(s) for the model.
+
+ Args:
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+ stereo, i.e. single float per timestep.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*, default to `True`):
+ Whether to return the attention mask. If left to the default, will return the attention mask according
+ to the specific feature_extractor's default. [What are attention masks?](../glossary#attention-mask)
+
+
+
+ For TvltTransformer models, `attention_mask` should alwys be passed for batched inference, to avoid
+ subtle bugs.
+
+
+
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
+ pipeline. Current model supports sampling rate 16000 and 44100.
+ resample (`bool`, *optional*, defaults to `False`):
+ If the sampling rate is not matched, resample the input audio to match.
+ mask_audio (`bool`, *optional*, defaults to `False`):
+ Whether or not to mask input audio for MAE task.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **audio_values** -- Audio values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+
+ - **audio_mask** -- Audio masks to be fed to a model, of shape (batch_size, num_audio_patches).
+ """
+
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ "This feature extractor is set to support sampling rate"
+ f" of {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled"
+ f" with {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+ if is_batched_numpy and len(raw_speech.shape) > 2:
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+ is_batched = is_batched_numpy or (
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
+ )
+ if is_batched:
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
+ raw_speech = raw_speech.astype(np.float32)
+ # always return batch
+ if not is_batched:
+ raw_speech = [np.asarray([raw_speech]).T]
+
+ # Convert audio signals to log mel spectrograms, truncate by time axis
+ audio_features = [
+ self._np_extract_fbank_features(waveform.squeeze()).T[: self.spectrogram_length] for waveform in raw_speech
+ ]
+ if isinstance(audio_features[0], List):
+ audio_features = [np.asarray(feature, dtype=np.float32) for feature in audio_features]
+
+ # Create audio attention mask
+ max_patch_len = max(
+ [ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len for feature in audio_features]
+ ) # The maximum number of audio patches in a batch
+ if return_attention_mask:
+ audio_mask = [
+ (ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len) * [1]
+ + (max_patch_len - ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len) * [0]
+ for feature in audio_features
+ ]
+ audio_mask = np.array(audio_mask).astype(np.float32)
+
+ # convert into correct format for padding
+ max_time_len = max_patch_len // self.freq_len * self.patch_size[0] # The maximum audio size in a batch
+ padded_audio_features = np.ones([len(audio_features), 1, max_time_len, self.feature_size]).astype(np.float32)
+ padded_audio_features = padded_audio_features * self.padding_value
+ for i in range(len(audio_features)):
+ feature = audio_features[i]
+ padded_audio_features[i, :, : feature.shape[0], :] = feature
+
+ # return as BatchFeature
+ if return_attention_mask:
+ data = {"audio_values": padded_audio_features, "audio_mask": audio_mask}
+ else:
+ data = {"audio_values": padded_audio_features}
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+ return encoded_inputs
+
+
+__all__ = ["TvltFeatureExtractor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py
new file mode 100644
index 0000000000000000000000000000000000000000..db87d6e8d5686e9d5f5b04ab9020d0e4bed7aa19
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py
@@ -0,0 +1,438 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for TVLT."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ....image_transforms import (
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ....image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_valid_image,
+ to_numpy_array,
+ valid_images,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ....utils import TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def make_batched(videos) -> List[List[ImageInput]]:
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)):
+ return videos
+
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+ videos_dim = np.array(videos[0]).ndim
+ if videos_dim == 3:
+ return [videos]
+ elif videos_dim == 4:
+ return videos
+
+ elif is_valid_image(videos):
+ videos_dim = np.array(videos).ndim
+ if videos_dim == 3:
+ return [[videos]]
+ elif videos_dim == 4:
+ return [videos]
+ elif videos_dim == 5:
+ return videos
+
+ raise ValueError(f"Could not make batched video from {videos}")
+
+
+class TvltImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a TVLT image processor.
+
+ This processor can be used to prepare either videos or images for the model by converting images to 1-frame videos.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+ `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
+ Size of the output image after resizing. The shortest edge of the image will be resized to
+ `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by
+ `size` in the `preprocess` method.
+ patch_size (`List[int]` *optional*, defaults to [16,16]):
+ The patch size of image patch embedding.
+ num_frames (`int` *optional*, defaults to 8):
+ The maximum number of video frames.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+ `preprocess` method.
+ do_center_crop (`bool`, *optional*, defaults to `True`):
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
+ parameter in the `preprocess` method.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+ Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
+ `preprocess` method.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+ parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to 1/255):
+ Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
+ in the `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+ method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ """
+
+ model_input_names = [
+ "pixel_values",
+ "pixel_mask",
+ "pixel_values_mixed",
+ "pixel_mask_mixed",
+ ]
+
+ def __init__(
+ self,
+ do_resize: bool = True,
+ size: Dict[str, int] = None,
+ patch_size: List[int] = [16, 16],
+ num_frames: int = 8,
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ do_center_crop: bool = True,
+ crop_size: Dict[str, int] = None,
+ do_rescale: bool = True,
+ rescale_factor: Union[int, float] = 1 / 255,
+ do_normalize: bool = True,
+ image_mean: Optional[Union[float, List[float]]] = IMAGENET_STANDARD_MEAN,
+ image_std: Optional[Union[float, List[float]]] = IMAGENET_STANDARD_STD,
+ init_mask_generator=False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ size = size if size is not None else {"shortest_edge": 224}
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+ self.do_resize = do_resize
+ self.size = size
+ self.patch_size = patch_size
+ self.num_frames = num_frames
+ self.do_center_crop = do_center_crop
+ self.crop_size = crop_size
+ self.resample = resample
+ self.do_rescale = do_rescale
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self._valid_processor_keys = [
+ "videos",
+ "do_resize",
+ "size",
+ "patch_size",
+ "num_frames",
+ "resample",
+ "do_center_crop",
+ "crop_size",
+ "do_rescale",
+ "rescale_factor",
+ "do_normalize",
+ "image_mean",
+ "image_std",
+ "is_mixed",
+ "return_tensors",
+ "data_format",
+ "input_data_format",
+ ]
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize an image.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
+ have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
+ shortest edge of length `s` while keeping the aspect ratio of the original image.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use when resiizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ size = get_size_dict(size, default_to_square=False)
+ if "shortest_edge" in size:
+ output_size = get_resize_output_image_size(
+ image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ output_size = (size["height"], size["width"])
+ else:
+ raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
+ return resize(
+ image,
+ size=output_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+
+ def _preprocess_image(
+ self,
+ image: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """Preprocesses a single image."""
+
+ validate_preprocess_arguments(
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ )
+
+ # All transformations expect numpy arrays.
+ image = to_numpy_array(image)
+
+ if do_rescale and is_scaled_image(image):
+ logger.warning_once(
+ "It looks like you are trying to rescale already rescaled images. If the input"
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+ )
+
+ if input_data_format is None:
+ input_data_format = infer_channel_dimension_format(image)
+
+ if do_resize:
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+ if do_center_crop:
+ image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
+
+ if do_rescale:
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+
+ if do_normalize:
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+ return image
+
+ def preprocess(
+ self,
+ videos: ImageInput,
+ do_resize: Optional[bool] = None,
+ size: Dict[str, int] = None,
+ patch_size: List[int] = None,
+ num_frames: Optional[int] = None,
+ resample: PILImageResampling = None,
+ do_center_crop: Optional[bool] = None,
+ crop_size: Dict[str, int] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, List[float]]] = None,
+ image_std: Optional[Union[float, List[float]]] = None,
+ is_mixed: bool = False,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: ChannelDimension = ChannelDimension.FIRST,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> BatchFeature:
+ """
+ Preprocess an videos or image or batch of videos or images.
+
+ Args:
+ videos (`ImageInput`):
+ Images or videos to preprocess. Expects a single or batch of frames with pixel values ranging from 0 to
+ 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`.
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+ Whether to resize the image.
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+ Size of the image after applying resize.
+ patch_size (`List[int]` *optional*, defaults to self.patch_size):
+ The patch size of image patch embedding.
+ num_frames (`int` *optional*, defaults to self.num_frames):
+ The maximum number of video frames.
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
+ has an effect if `do_resize` is set to `True`.
+ do_center_crop (`bool`, *optional*, defaults to `self.do_centre_crop`):
+ Whether to centre crop the image.
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+ Size of the image after applying the centre crop.
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+ Whether to rescale the image values between [0 - 1].
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+ Whether to normalize the image.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+ Image mean.
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+ Image standard deviation.
+ is_mixed (`bool`, *optional*):
+ If the input video has negative samples.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+ The channel dimension format for the output image. Can be one of:
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - Unset: Use the inferred channel dimension format of the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
+ from the input image. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
+ width).
+
+ - **pixel_mask** -- Pixel masks to be fed to a model, of shape (batch_size, num_pixel_patches).
+
+ - **pixel_values_mixed** -- Pixel values with both postive or negative to be fed to a model, of shape
+ (batch_size, num_channels, height, width).
+
+ - **pixel_mask_mixed** -- Pixel masks with both postive or negative to be fed to a model, of shape
+ (batch_size, num_pixel_patches).
+ """
+ do_resize = do_resize if do_resize is not None else self.do_resize
+ resample = resample if resample is not None else self.resample
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+ image_mean = image_mean if image_mean is not None else self.image_mean
+ image_std = image_std if image_std is not None else self.image_std
+
+ size = size if size is not None else self.size
+ size = get_size_dict(size, default_to_square=False)
+ crop_size = crop_size if crop_size is not None else self.crop_size
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
+ patch_size = patch_size if patch_size is not None else self.patch_size
+ num_frames = num_frames if patch_size is not None else self.num_frames
+
+ validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
+
+ if not valid_images(videos):
+ raise ValueError(
+ "Invalid image or video type. Must be of type PIL.Image.Image, numpy.ndarray, "
+ "torch.Tensor, tf.Tensor or jax.ndarray."
+ )
+
+ videos = make_batched(videos)
+
+ # Check number of frames is fewer than maximum frames
+ for video in videos:
+ if len(video) > self.num_frames:
+ raise ValueError(
+ f"number of frames must not be greater than the maximum frames of the model {self.num_frames}."
+ )
+
+ max_num_frames = max([len(video) for video in videos])
+ num_patches_per_image = (size["shortest_edge"] // patch_size[0]) ** 2
+ video_masks = np.array(
+ [
+ len(video) * num_patches_per_image * [1] + (max_num_frames - len(video)) * num_patches_per_image * [0]
+ for video in videos
+ ]
+ )
+
+ videos = [
+ [
+ self._preprocess_image(
+ image=img,
+ do_resize=do_resize,
+ size=size,
+ resample=resample,
+ do_center_crop=do_center_crop,
+ crop_size=crop_size,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ for img in video
+ ]
+ for video in videos
+ ]
+
+ # If videos contain both positive/negative, use mixed key for video-audio matching task
+ if is_mixed:
+ data = {"pixel_values_mixed": videos, "pixel_mask_mixed": video_masks}
+ else:
+ data = {"pixel_values": videos, "pixel_mask": video_masks}
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+__all__ = ["TvltImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py
new file mode 100644
index 0000000000000000000000000000000000000000..279224ac4d24a2c4b3a7694dbd20a77390a0a793
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py
@@ -0,0 +1,1291 @@
+# coding=utf-8
+# Copyright 2023 MURGe-Lab and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch TVLT model."""
+
+import collections.abc
+import math
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_tvlt import TvltConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "TvltConfig"
+_CHECKPOINT_FOR_DOC = "ZinengTang/tvlt-base"
+
+
@dataclass
class TvltModelOutput(ModelOutput):
    """
    Class for TvltModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`):
            Pixel sequence of hidden-states at the output of the last layer of the model.
        last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`):
            Audio sequence of hidden-states at the output of the last layer of the model.
        pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor indicating which pixel patches are masked (1) and which are not (0).
        audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor indicating which audio patches are masked (1) and which are not (0).
        pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor containing the ids permutation of pixel masking.
        audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor containing the ids permutation of audio masking.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # All fields default to None; the masking-related fields (label masks, ids_restore)
    # are only populated when mask_pixel / mask_audio is set in TvltModel.forward.
    last_hidden_state: Optional[torch.FloatTensor] = None
    last_pixel_hidden_state: Optional[torch.FloatTensor] = None
    last_audio_hidden_state: Optional[torch.FloatTensor] = None
    pixel_label_masks: Optional[torch.LongTensor] = None
    audio_label_masks: Optional[torch.LongTensor] = None
    pixel_ids_restore: Optional[torch.LongTensor] = None
    audio_ids_restore: Optional[torch.LongTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class TvltDecoderOutput(ModelOutput):
    """
    Class for TvltDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # All fields default to None so the optional tuples can simply be omitted.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class TvltForPreTrainingOutput(ModelOutput):
    """
    Class for TvltForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Matching objective logits.
        pixel_logits (`torch.FloatTensor` of shape
            `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction
            logits.
        audio_logits (`torch.FloatTensor` of shape
            `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction
            logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    # NOTE(review): the shape formulas in the docstring above look suspect
    # (`image_patch_size ** 3`, and `image_patch_size` used for the audio logits) —
    # verify against TvltForPreTraining's decoder heads before relying on them.
    loss: Optional[torch.FloatTensor] = None
    matching_logits: Optional[torch.FloatTensor] = None
    pixel_logits: Optional[torch.FloatTensor] = None
    audio_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
def generate_pixel_mask_noise(pixel_values, pixel_mask=None, mask_ratio=0.75):
    """Generate per-patch random noise for pixel (video) masking.

    The original docstring said "audio masking" — a copy-paste error from
    `generate_audio_mask_noise`; this function drives *pixel* masking.

    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, seq_len, ...)`):
            Embedded pixel patches; only the first two dimensions are read.
        pixel_mask (`torch.FloatTensor`, *optional*):
            Unused; kept for interface symmetry with `generate_audio_mask_noise`.
        mask_ratio (`float`, *optional*, defaults to 0.75):
            Fraction of patches to mask out.

    Returns:
        `Tuple[torch.FloatTensor, int]`: uniform noise in `[0, 1]` of shape
        `(batch_size, seq_len)`, and the number of patches to keep.
    """
    batch_size, seq_len = pixel_values.shape[:2]
    noise = torch.rand((batch_size, seq_len), device=pixel_values.device)  # noise in [0, 1]
    len_keep = int(seq_len * (1 - mask_ratio))
    return noise, len_keep
+
+
def generate_audio_mask_noise(audio_values, audio_mask=None, mask_ratio=0.75, mask_type="patch-level", freq_len=8):
    """Generate random noise for audio masking, either per patch or per time frame.

    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, seq_len, ...)`):
            Embedded audio patches; only the first two dimensions are read.
        audio_mask (`torch.FloatTensor`, *optional*):
            Unused; kept for interface symmetry with `generate_pixel_mask_noise`.
        mask_ratio (`float`, *optional*, defaults to 0.75):
            Fraction of patches to mask out.
        mask_type (`str`, *optional*, defaults to `"patch-level"`):
            `"frame-level"` shares one noise value across the `freq_len` frequency
            patches of each time frame (whole frames kept/dropped together);
            `"patch-level"` draws independent noise per patch.
        freq_len (`int`, *optional*, defaults to 8):
            Number of frequency patches per time frame (frame-level only).

    Returns:
        `Tuple[torch.FloatTensor, int]`: noise of shape `(batch_size, seq_len)` in
        `[0, 1]`, and the number of patches to keep.

    Raises:
        ValueError: If `mask_type` is not `"frame-level"` or `"patch-level"`.
            (Previously an unknown mask_type fell through to an UnboundLocalError.)
    """
    batch_size, seq_len = audio_values.shape[:2]
    if mask_type == "frame-level":
        num_time_patches = seq_len // freq_len
        # One noise value per time frame, broadcast over its frequency patches.
        noise = (
            torch.rand(batch_size, num_time_patches, device=audio_values.device)
            .unsqueeze(-1)
            .repeat(1, 1, freq_len)
            .view(batch_size, seq_len)
        )  # noise in [0, 1]
    elif mask_type == "patch-level":
        noise = torch.rand(batch_size, seq_len, device=audio_values.device)  # noise in [0, 1]
    else:
        raise ValueError(f"mask_type must be 'frame-level' or 'patch-level', got {mask_type!r}")
    len_keep = int(seq_len * (1 - mask_ratio))
    return noise, len_keep
+
+
def random_masking(sequence, noise, len_keep, attention_masks=None):
    """
    Randomly mask a sequence by per-sample shuffling driven by `noise`.

    The `len_keep` patches with the smallest noise values are kept; the rest are
    dropped. Returns the kept sub-sequence `(batch, len_keep, hidden)`, the
    attention masks gathered to the kept positions (or None), a binary label mask
    in the *original* order (0 = kept, 1 = removed), and the index permutation
    needed to restore the original ordering.
    """
    batch_size, seq_len, hidden_dim = sequence.shape

    # Ascending argsort of the noise: small noise == keep, large noise == remove.
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    # Gather the kept patches, expanding the keep indices over the hidden dim.
    gather_index = ids_keep.unsqueeze(-1).repeat(1, 1, hidden_dim)
    sequence_masked = torch.gather(sequence, dim=1, index=gather_index)

    # Binary mask in shuffled order (first len_keep entries kept), then unshuffled
    # back to the original patch order.
    label_masks = torch.ones([batch_size, seq_len], device=sequence.device)
    label_masks[:, :len_keep] = 0
    label_masks = torch.gather(label_masks, dim=1, index=ids_restore)

    if attention_masks is not None:
        # Padded positions never count as reconstruction targets.
        label_masks = label_masks * attention_masks
        attention_masks = torch.gather(attention_masks, dim=1, index=ids_keep)

    return sequence_masked, attention_masks, label_masks, ids_restore
+
+
class TvltPixelEmbeddings(nn.Module):
    """Construct the patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltPixelPatchEmbeddings(config)
        self.num_patches_per_image = self.patch_embeddings.num_patches_per_image

        # Modality type embedding shared by every pixel patch.
        self.type_embed_v = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # One temporal embedding per frame, one spatial embedding per patch within a frame.
        self.temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.num_patches_per_image, config.hidden_size))

        self.config = config

    def forward(self, pixel_values, attention_masks=None):
        # create patch embeddings
        batch_size, num_frames, num_channels, height, width = pixel_values.shape

        embeddings = self.patch_embeddings(pixel_values)
        # The spatial position embedding tiles across frames; the temporal embedding
        # is repeated so all patches of one frame share that frame's embedding.
        embeddings += self.pos_embed_v.repeat(1, num_frames, 1)
        embeddings += torch.repeat_interleave(self.temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1)
        embeddings += self.type_embed_v

        # attention_masks is passed through unchanged.
        return embeddings, attention_masks
+
+
class TvltAudioEmbeddings(nn.Module):
    """Construct the audio patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltAudioPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches

        # Modality type embedding shared by every audio patch.
        self.type_embed_a = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        # Number of patches along the frequency axis of the spectrogram.
        # (Previously this attribute was redundantly computed a second time below.)
        self.num_freq_patches = config.frequency_length // config.audio_patch_size[1]
        # Positional embedding over time patches, and embedding over frequency patches.
        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.num_patches // self.num_freq_patches, config.hidden_size))
        self.freq_embed = nn.Parameter(torch.zeros(1, self.num_freq_patches, config.hidden_size))

        self.config = config

    def forward(self, audio_values, attention_masks=None):
        # create patch embeddings
        embeddings = self.patch_embeddings(audio_values)

        num_time_patches = embeddings.size(1) // self.num_freq_patches
        # The frequency embedding tiles across time patches; the temporal embedding
        # is interleaved so every frequency patch of a frame shares its value.
        embeddings += self.freq_embed.repeat(1, num_time_patches, 1)
        embeddings += torch.repeat_interleave(self.pos_embed_a[:, :num_time_patches], self.num_freq_patches, dim=1)
        embeddings += self.type_embed_a

        # attention_masks is passed through unchanged.
        return embeddings, attention_masks
+
+
class TvltPixelPatchEmbeddings(nn.Module):
    """
    Turns `pixel_values` of shape `(batch_size, num_frames, num_channels, height, width)` into patch
    embeddings of shape `(batch_size, num_frames * num_patches_per_image, hidden_size)` via a strided
    convolution, ready to be consumed by a Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size = config.image_size
        patch_size = config.image_patch_size
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_image_channels
        self.hidden_size = config.hidden_size
        self.num_patches_per_image = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # Non-overlapping patch projection: kernel == stride == patch size.
        self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        # Fold frames into the batch dimension, project each frame, then restore a
        # single flat patch sequence spanning all frames.
        frames = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
        patches = self.projection(frames).flatten(2).transpose(1, 2)
        return patches.reshape(batch_size, num_frames * self.num_patches_per_image, self.hidden_size)
+
+
class TvltAudioPatchEmbeddings(nn.Module):
    """
    This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        spectrogram_length, frequency_length, patch_size = (
            config.spectrogram_length,
            config.frequency_length,
            config.audio_patch_size,
        )
        num_channels, hidden_size = config.num_audio_channels, config.hidden_size

        spectrogram_size = (spectrogram_length, frequency_length)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        # Patch grid over the (time, frequency) plane of the spectrogram.
        num_patches = (spectrogram_size[1] // patch_size[1]) * (spectrogram_size[0] // patch_size[0])
        patch_shape = (spectrogram_size[0] // patch_size[0], spectrogram_size[1] // patch_size[1])
        self.spectrogram_size = spectrogram_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        # Non-overlapping patch projection: kernel == stride == patch size.
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, audio_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = audio_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # NOTE(review): the time axis is checked with `>` (shorter spectrograms are
        # accepted) while the frequency axis must match exactly — presumably to allow
        # variable-length audio; confirm before tightening.
        if height > self.spectrogram_size[0] or width != self.spectrogram_size[1]:
            raise ValueError(
                f"Input audio size ({height}*{width}) doesn't match model"
                f" ({self.spectrogram_size[0]}*{self.spectrogram_size[1]})."
            )
        embeddings = self.projection(audio_values).flatten(2).transpose(1, 2)

        return embeddings
+
+
class TvltSelfAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention (ViT-style)."""

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape `(batch, seq, all_head_size)` to `(batch, num_heads, seq, head_size)`."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        """
        Args:
            hidden_states: input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: additive mask broadcastable to the attention scores.
            head_mask: multiplicative per-head mask applied to the probabilities.
            output_attentions: if True, also return the attention probabilities.

        Returns:
            `(context,)` or `(context, attention_probs)`.
        """
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities. Use the functional form
        # instead of instantiating a fresh nn.Softmax module on every forward call.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
class TvltSelfOutput(nn.Module):
    """
    Projects the attention context back to `hidden_size`. The residual connection is
    defined in TvltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # NOTE: `input_tensor` is intentionally unused; the residual add happens in TvltLayer.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
+
+
class TvltAttention(nn.Module):
    """Self-attention plus its output projection, with support for head pruning."""

    def __init__(self, config):
        super().__init__()
        self.attention = TvltSelfAttention(config)
        self.output = TvltSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Prune the given attention heads; indices are relative to the original head count."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
+
+
class TvltIntermediate(nn.Module):
    """Feed-forward expansion: `hidden_size -> intermediate_size`, followed by the activation."""

    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Accept either the name of an activation (looked up in ACT2FN) or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class TvltOutput(nn.Module):
    """Feed-forward contraction: `intermediate_size -> hidden_size`, dropout, then the residual add."""

    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Second residual connection of the transformer block.
        return projected + input_tensor
+
+
class TvltLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = TvltAttention(config)
        self.intermediate = TvltIntermediate(config)
        self.output = TvltOutput(config)
        # Pre-norm transformer block: layernorm before attention and before the MLP.
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViLT, layernorm is applied before self-attention
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        # NOTE(review): `.to(attention_output.device)` presumably guards against inputs
        # landing on a different device (e.g. model parallelism) — confirm before removing.
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # in ViLT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
+
+
class TvltEncoder(nn.Module):
    """Stack of `TvltLayer` blocks with optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([TvltLayer(config) for _ in range(config.num_hidden_layers)])
        # Flipped to True externally (via PreTrainedModel) to trade compute for memory.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the input to each layer; the final output is appended after the loop.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output: keep only the pieces that were actually collected.
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
class TvltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TvltConfig
    base_model_prefix = "tvlt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm starts as an identity affine transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
+
+
# Shared docstring fragments injected into the model classes by the
# add_start_docstrings / add_start_docstrings_to_model_forward decorators below.
TVLT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

TVLT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`):
            Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`):
            Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        mask_pixel (`bool`, *optional*):
            Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining.

        mask_audio (`bool`, *optional*):
            Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.",
    TVLT_START_DOCSTRING,
)
class TvltModel(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.pixel_embeddings = TvltPixelEmbeddings(config)
        self.audio_embeddings = TvltAudioEmbeddings(config)
        self.encoder = TvltEncoder(config)

        # Single [CLS]-style token prepended to the concatenated pixel+audio sequence.
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        if config.use_mean_pooling:
            # No final layernorm; downstream heads are expected to mean-pool.
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Both modality patch embedders: (pixel, audio).
        return self.pixel_embeddings.patch_embeddings, self.audio_embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        mask_pixel: bool = False,
        mask_audio: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltModel
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))

        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base")

        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        pixel_embedding_output, pixel_mask = self.pixel_embeddings(pixel_values, pixel_mask)

        audio_embedding_output, audio_mask = self.audio_embeddings(audio_values, audio_mask)

        # Mask pixel if mask_pixel is True
        pixel_label_masks = None
        pixel_ids_restore = None
        if mask_pixel:
            pixel_mask_noise, pixel_len_keep = generate_pixel_mask_noise(
                pixel_embedding_output, pixel_mask=pixel_mask, mask_ratio=self.config.pixel_mask_ratio
            )
            pixel_embedding_output, pixel_mask, pixel_label_masks, pixel_ids_restore = random_masking(
                pixel_embedding_output,
                pixel_mask_noise,
                pixel_len_keep,
                attention_masks=pixel_mask,
            )

        # Mask audio if mask_audio is True
        audio_label_masks = None
        audio_ids_restore = None
        if mask_audio:
            num_freq_patches = self.config.frequency_length // self.config.audio_patch_size[1]
            audio_mask_noise, audio_len_keep = generate_audio_mask_noise(
                audio_embedding_output,
                audio_mask=audio_mask,
                mask_ratio=self.config.audio_mask_ratio,
                mask_type=self.config.audio_mask_type,
                freq_len=num_freq_patches,
            )
            audio_embedding_output, audio_mask, audio_label_masks, audio_ids_restore = random_masking(
                audio_embedding_output,
                audio_mask_noise,
                audio_len_keep,
                attention_masks=audio_mask,
            )

        # Prepare for encoder inputs and attention masks
        # Sequence layout: [CLS] + (possibly masked) pixel patches + (possibly masked) audio patches.
        batch_size = pixel_values.size(0)
        embedding_output = torch.cat(
            [self.cls_embedding.repeat(batch_size, 1, 1), pixel_embedding_output, audio_embedding_output], 1
        )
        masked_pixel_len = pixel_embedding_output.size(1)

        attention_mask = None
        if pixel_mask is not None and audio_mask is not None:
            # pixel_mask[:, :1] stands in for the cls token's mask slot.
            attention_mask = torch.cat([pixel_mask[:, :1], pixel_mask, audio_mask], 1)

        input_shape = embedding_output.size()
        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        # Split the encoded sequence back into its per-modality parts (index 0 is the cls token).
        pixel_sequence_output = sequence_output[:, 1 : 1 + masked_pixel_len]
        audio_sequence_output = sequence_output[:, 1 + masked_pixel_len :]
        if not return_dict:
            return (
                sequence_output,
                pixel_sequence_output,
                audio_sequence_output,
                pixel_label_masks,
                audio_label_masks,
                pixel_ids_restore,
                audio_ids_restore,
            ) + encoder_outputs[1:]

        return TvltModelOutput(
            last_hidden_state=sequence_output,
            last_pixel_hidden_state=pixel_sequence_output,
            last_audio_hidden_state=audio_sequence_output,
            pixel_label_masks=pixel_label_masks,
            audio_label_masks=audio_label_masks,
            pixel_ids_restore=pixel_ids_restore,
            audio_ids_restore=audio_ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
class TvltDecoder(nn.Module):
    """Transformer decoder used by TVLT for masked-autoencoder reconstruction.

    A stack of `TvltLayer` blocks built from a copy of `config` whose sizes are
    overridden with the decoder-specific hyperparameters, followed by a final
    layer norm that produces the decoder "logits" (pre-head hidden states).
    """

    def __init__(self, config):
        super().__init__()

        # Clone the encoder config and swap in the decoder-side dimensions so
        # TvltLayer builds attention/MLP blocks of the decoder width.
        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            TvltLayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)
        )

        self.layernorm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run the decoder stack; optionally collect per-layer states/attentions."""
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        for layer in self.decoder_layers:
            if collected_states is not None:
                collected_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward to save memory.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer(hidden_states, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]
            if collected_attentions is not None:
                collected_attentions += (layer_outputs[1],)

        if collected_states is not None:
            collected_states += (hidden_states,)

        # Final norm acts as the "predictor projection" before the MAE heads.
        logits = self.layernorm(hidden_states)

        if not return_dict:
            return tuple(v for v in (logits, collected_states, collected_attentions) if v is not None)
        return TvltDecoderOutput(logits=logits, hidden_states=collected_states, attentions=collected_attentions)
+
+
@add_start_docstrings(
    "The TVLT Model transformer with the decoder on top for self-supervised pre-training.",
    TVLT_START_DOCSTRING,
)
class TvltForPreTraining(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # At least one of the two pre-training objectives must be active.
        self.task_matching = config.task_matching
        self.task_mae = config.task_mae
        if not (self.task_matching or self.task_mae):
            raise ValueError("Must set at least one of matching task and MAE task to true")

        self.tvlt = TvltModel(config)

        if self.task_matching:
            self.matching_head = TvltMatchingHead(config)

        if self.task_mae:
            # Projects encoder hidden states into the (smaller) decoder width.
            self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)

            # Learnable tokens inserted at masked positions before decoding.
            self.pixel_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
            self.audio_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))

            self.decoder = TvltDecoder(config)

            decoder_hidden_size = config.decoder_hidden_size

            # Decoder-side positional/type embeddings for the pixel stream.
            num_frames = config.num_frames
            num_patches_per_image = self.tvlt.pixel_embeddings.num_patches_per_image
            self.decoder_pixel_pos_embed = nn.Parameter(torch.zeros(1, num_patches_per_image, decoder_hidden_size))
            self.decoder_temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, decoder_hidden_size))
            self.decoder_pixel_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            # Decoder-side positional/type embeddings for the audio stream.
            num_audio_patches = self.tvlt.audio_embeddings.num_patches
            num_freq_patches = config.frequency_length // config.audio_patch_size[1]
            self.decoder_audio_pos_embed = nn.Parameter(
                torch.zeros(1, num_audio_patches // num_freq_patches, decoder_hidden_size)
            )
            self.decoder_freq_embed = nn.Parameter(torch.zeros(1, num_freq_patches, decoder_hidden_size))
            self.decoder_audio_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            # Reconstruction heads predict flattened per-patch values.
            pixel_mae_output_dim = self.config.image_patch_size[0] ** 2 * self.config.num_image_channels
            self.pixel_mae_head = TvltMAEHead(config, pixel_mae_output_dim)
            audio_mae_output_dim = (
                self.config.audio_patch_size[0] * self.config.audio_patch_size[1] * self.config.num_audio_channels
            )
            self.audio_mae_head = TvltMAEHead(config, audio_mae_output_dim)

            self.num_frames = num_frames
            self.num_patches_per_image = num_patches_per_image
            self.num_freq_patches = num_freq_patches
            self.image_patch_size = config.image_patch_size
            self.audio_patch_size = config.audio_patch_size

        # Initialize weights and apply final processing
        self.post_init()

    def patchify_pixel(self, pixel_values):
        """
        Flatten video frames into per-patch vectors.

        pixel_values: [batch_size, num_frames, 3, height, width]

        Returns a tensor of shape
        [batch_size, num_frames * num_patches_height * num_patches_width,
         patch_height * patch_width * num_channels].
        """
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        num_patches_height = pixel_values.shape[3] // self.image_patch_size[0]
        num_patches_width = pixel_values.shape[4] // self.image_patch_size[1]
        patchified_pixel_values = pixel_values.reshape(
            shape=(
                batch_size,
                num_frames,
                num_channels,
                num_patches_height,
                self.image_patch_size[0],
                num_patches_width,
                self.image_patch_size[1],
            )
        )
        # Move channels last so each patch is a contiguous (p, q, c) block.
        patchified_pixel_values = torch.einsum("ntchpwq->nthwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width * num_frames,
                self.image_patch_size[0] * self.image_patch_size[1] * num_channels,
            )
        )
        return patchified_pixel_values

    def patchify_audio(self, audio_values):
        """
        Flatten a spectrogram into per-patch vectors.

        audio_values: [batch_size, 1, height, width]

        Returns a tensor of shape
        [batch_size, num_patches_height * num_patches_width,
         patch_height * patch_width * num_channels].
        """
        batch_size, num_channels, height, width = audio_values.shape
        num_patches_height = height // self.audio_patch_size[0]
        num_patches_width = width // self.audio_patch_size[1]
        patchified_audio_values = audio_values.reshape(
            shape=(
                batch_size,
                num_channels,
                num_patches_height,
                self.audio_patch_size[0],
                num_patches_width,
                self.audio_patch_size[1],
            )
        )
        patchified_audio_values = torch.einsum("nchpwq->nhwpqc", patchified_audio_values)
        patchified_audio_values = patchified_audio_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width,
                self.audio_patch_size[0] * self.audio_patch_size[1] * num_channels,
            )
        )
        return patchified_audio_values

    def pixel_mae_loss(self, pixel_values, pixel_predictions, mask):
        """Mean squared reconstruction error over the *masked* pixel patches only."""
        patchified_pixel_values = self.patchify_pixel(pixel_values)
        loss = (pixel_predictions - patchified_pixel_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, pixel_pixel_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def audio_mae_loss(self, audio_values, audio_predictions, mask):
        """Mean squared reconstruction error over the *masked* audio patches only."""
        patchified_audio_values = self.patchify_audio(audio_values)
        loss = (audio_predictions - patchified_audio_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, audio_pixel_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def concatenate_mask(self, mask_token, sequence, ids_restore):
        """Re-insert mask tokens at the masked positions and restore the original patch order."""
        batch_size, seq_length, dim = sequence.shape
        mask_tokens = mask_token.repeat(batch_size, ids_restore.shape[1] - seq_length, 1)
        padded_sequence = torch.cat([sequence, mask_tokens], dim=1)
        padded_sequence = torch.gather(
            padded_sequence, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim)
        )  # unshuffle
        return padded_sequence

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values_mixed: Optional[torch.FloatTensor] = None,
        pixel_mask_mixed: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
        r"""
        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
            obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel values mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1.

        Return:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(
        ...     images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt"
        ... )

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        total_loss = 0.0
        # Fix: initialize these before the task branches so the `not return_dict`
        # path below never references an unbound name when `task_matching` is
        # disabled (previously a NameError on `matching_logits`/`loss`).
        loss = None
        matching_logits = None

        if self.task_matching:
            if labels is None:
                raise ValueError("Matching task requires labels")
            if pixel_values_mixed is None:
                raise ValueError("Matching task requires pixel_values_mixed")

            outputs = self.tvlt(
                pixel_values_mixed,
                audio_values,
                pixel_mask=pixel_mask_mixed,
                audio_mask=audio_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output = outputs[0]
            matching_logits = self.matching_head(sequence_output)

            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(matching_logits.view(-1), labels.view(-1))
            total_loss += loss

        pixel_logits = None
        audio_logits = None
        # MAE reconstruction is only computed during training.
        if self.task_mae and self.training:
            outputs = self.tvlt(
                pixel_values,
                audio_values,
                pixel_mask=pixel_mask,
                audio_mask=audio_mask,
                mask_pixel=True,
                mask_audio=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            pixel_sequence_output = outputs.last_pixel_hidden_state if return_dict else outputs[1]
            audio_sequence_output = outputs.last_audio_hidden_state if return_dict else outputs[2]
            pixel_label_masks = outputs.pixel_label_masks if return_dict else outputs[3]
            audio_label_masks = outputs.audio_label_masks if return_dict else outputs[4]
            pixel_ids_restore = outputs.pixel_ids_restore if return_dict else outputs[5]
            audio_ids_restore = outputs.audio_ids_restore if return_dict else outputs[6]

            pixel_decoder_input = self.encoder_to_decoder(
                pixel_sequence_output
            )  # [batch_size, num_masked_pixel_patches, decoder_hidden_size]
            audio_decoder_input = self.encoder_to_decoder(
                audio_sequence_output
            )  # [batch_size, num_masked_audio_patches, decoder_hidden_size]
            num_frames = pixel_values.size(1)
            # Restore full pixel sequence, then add spatial + temporal + type embeddings.
            pixel_decoder_input = self.concatenate_mask(self.pixel_mask_token, pixel_decoder_input, pixel_ids_restore)
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_pos_embed.repeat(1, num_frames, 1)
            pixel_decoder_input = pixel_decoder_input + torch.repeat_interleave(
                self.decoder_temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1
            )
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_type_embed
            pixel_decoder_outputs = self.decoder(pixel_decoder_input)
            pixel_logits = self.pixel_mae_head(pixel_decoder_outputs.logits)

            # Restore full audio sequence, then add frequency + time + type embeddings.
            audio_decoder_input = self.concatenate_mask(self.audio_mask_token, audio_decoder_input, audio_ids_restore)
            num_time_patches = audio_decoder_input.size(1) // self.num_freq_patches
            audio_decoder_input = audio_decoder_input + self.decoder_freq_embed.repeat(1, num_time_patches, 1)
            audio_decoder_input = audio_decoder_input + torch.repeat_interleave(
                self.decoder_audio_pos_embed[:, :num_time_patches], self.num_freq_patches, dim=1
            )
            audio_decoder_input = audio_decoder_input + self.decoder_audio_type_embed
            audio_decoder_outputs = self.decoder(audio_decoder_input)
            audio_logits = self.audio_mae_head(audio_decoder_outputs.logits)

            loss = self.pixel_mae_loss(pixel_values, pixel_logits, pixel_label_masks) + self.audio_mae_loss(
                audio_values, audio_logits, audio_label_masks
            )
            total_loss += loss

        # NOTE(review): if only the MAE task is enabled and the model is in eval
        # mode, neither branch above runs and `outputs` is never assigned; this
        # limitation is pre-existing — confirm intended usage before relying on
        # inference with a matching-free configuration.
        if not return_dict:
            output = (matching_logits, pixel_logits, audio_logits) + outputs[7:]
            return ((total_loss,) + output) if loss is not None else output

        return TvltForPreTrainingOutput(
            loss=total_loss,
            matching_logits=matching_logits,
            pixel_logits=pixel_logits,
            audio_logits=audio_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
class TvltPooler(nn.Module):
    """Pools the encoder output by passing the [CLS] token state through a tanh-activated dense layer."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pooling" here means simply selecting the hidden state of the first
        # ([CLS]) token, then projecting and squashing it.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
+
+
class TvltMatchingHead(nn.Module):
    """Vision-audio matching head: pools the sequence and projects to a single logit."""

    def __init__(self, config):
        super().__init__()
        self.pooler = TvltPooler(config)
        self.fc = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        pooled = self.pooler(hidden_states)
        return self.fc(pooled)
+
+
class TvltMAEHead(nn.Module):
    """Reconstruction head mapping decoder hidden states to flattened patch values."""

    def __init__(self, config, output_dim=None):
        super().__init__()
        self.config = config
        self.decoder = nn.Linear(config.decoder_hidden_size, output_dim)

    def forward(self, hidden_states):
        return self.decoder(hidden_states)
+
+
@add_start_docstrings(
    """
    Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token)
    for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval.
    """,
    TVLT_START_DOCSTRING,
)
class TvltForAudioVisualClassification(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.tvlt = TvltModel(config)

        # Classifier head: an MLP (hidden_size -> 2*hidden_size -> num_labels)
        # applied to the [CLS] token's final hidden state.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.LayerNorm(config.hidden_size * 2, eps=config.layer_norm_eps),
            nn.GELU(),
            nn.Linear(config.hidden_size * 2, config.num_labels),
        )
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
            refers to the number of classes in audiovisual tasks.

        Return:

        Examples:
        ```python
        >>> from transformers import TvltProcessor, TvltForAudioVisualClassification
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode both modalities jointly with the shared TVLT encoder.
        outputs = self.tvlt(
            pixel_values,
            audio_values,
            pixel_mask=pixel_mask,
            audio_mask=audio_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Classify from the [CLS] token (position 0) of the last hidden state.
        sequence_output = outputs[0][:, 0]
        logits = self.classifier(sequence_output)  # rank value

        # Loss type is selected by config: MSE for regression-style targets,
        # cross-entropy for classification targets.
        loss = None
        if labels is not None:
            if self.config.loss_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits, labels)
            elif self.config.loss_type == "classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # NOTE(review): `outputs[4:]` indexes into TvltModel's positional
            # tuple output; the pre-training model slices `outputs[7:]` for the
            # same tuple layout — confirm this index is intentional.
            output = (logits,) + outputs[4:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
+__all__ = ["TvltModel", "TvltForPreTraining", "TvltForAudioVisualClassification", "TvltPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9f8e0978d8a2c4158094ebdf3bc2279700e7ba5
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py
@@ -0,0 +1,92 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for TVLT.
+"""
+
+from ....processing_utils import ProcessorMixin
+
+
class TvltProcessor(ProcessorMixin):
    r"""
    Constructs a TVLT processor which wraps a TVLT image processor and TVLT feature extractor into a single processor.

    [`TvltProcessor`] offers all the functionalities of [`TvltImageProcessor`] and [`TvltFeatureExtractor`]. See the
    docstring of [`~TvltProcessor.__call__`] for more information.

    Args:
        image_processor (`TvltImageProcessor`):
            An instance of [`TvltImageProcessor`]. The image processor is a required input.
        feature_extractor (`TvltFeatureExtractor`):
            An instance of [`TvltFeatureExtractor`]. The feature extractor is a required input.
    """

    attributes = ["image_processor", "feature_extractor"]
    image_processor_class = "TvltImageProcessor"
    feature_extractor_class = "TvltFeatureExtractor"

    def __init__(self, image_processor, feature_extractor):
        super().__init__(image_processor=image_processor, feature_extractor=feature_extractor)

        self.image_processor = image_processor
        self.feature_extractor = feature_extractor

    def __call__(
        self,
        images=None,
        audio=None,
        images_mixed=None,
        sampling_rate=None,
        mask_audio=False,
        mask_pixel=False,
        *args,
        **kwargs,
    ):
        """
        Forwards the `images` argument to TvltImageProcessor's [`~TvltImageProcessor.preprocess`] and the `audio`
        argument to TvltFeatureExtractor's [`~TvltFeatureExtractor.__call__`]. Please refer to the docstring of the
        above two methods for more information.
        """
        if images is None and audio is None:
            raise ValueError("You need to specify either an `images` or `audio` input to process.")

        images_dict = None
        images_mixed_dict = None
        audio_dict = None

        if images is not None:
            images_dict = self.image_processor(images, *args, mask_pixel=mask_pixel, **kwargs)
            # Mixed (positive/negative) frames are only processed when regular
            # images were also supplied.
            if images_mixed is not None:
                images_mixed_dict = self.image_processor(images_mixed, *args, is_mixed=True, **kwargs)
        if audio is not None:
            audio_dict = self.feature_extractor(
                audio, *args, sampling_rate=sampling_rate, mask_audio=mask_audio, **kwargs
            )

        # Merge partial results; later updates win on key collisions, keeping
        # the original precedence: audio, then images, then mixed images.
        merged = {}
        for partial in (audio_dict, images_dict, images_mixed_dict):
            if partial is not None:
                merged.update(partial)
        return merged

    @property
    def model_input_names(self):
        """Order-preserving, de-duplicated union of both sub-processors' input names."""
        combined = self.image_processor.model_input_names + self.feature_extractor.model_input_names
        return list(dict.fromkeys(combined))
+
+
+__all__ = ["TvltProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9552c827365d3913539be3eafaf822146ada5829
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers and IDEs: expose the real submodule symbols
    # eagerly so they resolve at analysis time.
    from .configuration_van import *
    from .modeling_van import *
else:
    import sys

    # At runtime, replace this module object in sys.modules with a _LazyModule
    # proxy, so the heavy submodules are imported only when one of their
    # attributes is first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py b/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f1db7a4b48e7fb32b986cdc557b7aa3d328ed0
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py
@@ -0,0 +1,110 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VAN model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class VanConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VanModel`]. It is used to instantiate a VAN model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the VAN
    [Visual-Attention-Network/van-base](https://huggingface.co/Visual-Attention-Network/van-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`):
            Patch size to use in each stage's embedding layer.
        strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`):
            Stride size to use in each stage's embedding layer to downsample the input.
        hidden_sizes (`List[int]`, *optional*, defaults to `[64, 128, 320, 512]`):
            Dimensionality (hidden size) at each stage.
        depths (`List[int]`, *optional*, defaults to `[3, 3, 12, 3]`):
            Depth (number of layers) for each stage.
        mlp_ratios (`List[int]`, *optional*, defaults to `[8, 8, 4, 4]`):
            The expansion ratio for mlp layer at each stage.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in each layer. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        layer_scale_init_value (`float`, *optional*, defaults to 0.01):
            The initial value for layer scaling.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for stochastic depth.
        dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for dropout.

    Example:
    ```python
    >>> from transformers import VanModel, VanConfig

    >>> # Initializing a VAN van-base style configuration
    >>> configuration = VanConfig()
    >>> # Initializing a model from the van-base style configuration
    >>> model = VanModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "van"

    def __init__(
        self,
        image_size=224,
        num_channels=3,
        patch_sizes=None,
        strides=None,
        hidden_sizes=None,
        depths=None,
        mlp_ratios=None,
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        layer_scale_init_value=1e-2,
        drop_path_rate=0.0,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.num_channels = num_channels
        # Fix: use None sentinels instead of mutable list default arguments so
        # the defaults cannot be shared (and accidentally mutated) across
        # instances; the effective default values are unchanged.
        self.patch_sizes = patch_sizes if patch_sizes is not None else [7, 3, 3, 3]
        self.strides = strides if strides is not None else [4, 2, 2, 2]
        self.hidden_sizes = hidden_sizes if hidden_sizes is not None else [64, 128, 320, 512]
        self.depths = depths if depths is not None else [3, 3, 12, 3]
        self.mlp_ratios = mlp_ratios if mlp_ratios is not None else [8, 8, 4, 4]
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.layer_scale_init_value = layer_scale_init_value
        self.drop_path_rate = drop_path_rate
        self.dropout_rate = dropout_rate
+
+
+__all__ = ["VanConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd87217f051a04a6a339e93e6d77cfc85c701832
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py
@@ -0,0 +1,290 @@
+# coding=utf-8
+# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert VAN checkpoints from the original repository.
+
+URL: https://github.com/Visual-Attention-Network/VAN-Classification"""
+
+import argparse
+import json
+import sys
+from dataclasses import dataclass, field
+from functools import partial
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from huggingface_hub import cached_download, hf_hub_download
+from torch import Tensor
+
+from transformers import AutoImageProcessor, VanConfig, VanForImageClassification
+from transformers.models.deprecated.van.modeling_van import VanLayerScaling
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class Tracker:
+ module: nn.Module
+ traced: List[nn.Module] = field(default_factory=list)
+ handles: list = field(default_factory=list)
+
+ def _forward_hook(self, m, inputs: Tensor, outputs: Tensor):
+ has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d)
+ if has_not_submodules:
+ if not isinstance(m, VanLayerScaling):
+ self.traced.append(m)
+
+ def __call__(self, x: Tensor):
+ for m in self.module.modules():
+ self.handles.append(m.register_forward_hook(self._forward_hook))
+ self.module(x)
+ [x.remove() for x in self.handles]
+ return self
+
+ @property
+ def parametrized(self):
+ # check the len of the state_dict keys to see if we have learnable params
+ return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced))
+
+
+@dataclass
+class ModuleTransfer:
+    # Source (original) and destination (HF) modules whose weights are aligned
+    # by tracing one forward pass through each.
+    src: nn.Module
+    dest: nn.Module
+    verbose: int = 0
+    # Module types to drop on either side before pairing traced operations.
+    src_skip: List = field(default_factory=list)
+    dest_skip: List = field(default_factory=list)
+
+    def __call__(self, x: Tensor):
+        """
+        Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the
+        hood we tracked all the operations in both modules.
+        """
+        dest_traced = Tracker(self.dest)(x).parametrized
+        src_traced = Tracker(self.src)(x).parametrized
+
+        src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced))
+        dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced))
+
+        # The transfer is purely positional: it only works if both models run
+        # the same number of parametrized leaf modules in the same order.
+        if len(dest_traced) != len(src_traced):
+            raise Exception(
+                f"Numbers of operations are different. Source module has {len(src_traced)} operations while"
+                f" destination module has {len(dest_traced)}."
+            )
+
+        for dest_m, src_m in zip(dest_traced, src_traced):
+            dest_m.load_state_dict(src_m.state_dict())
+            if self.verbose == 1:
+                print(f"Transfered from={src_m} to={dest_m}")
+
+
+def copy_parameters(from_model: nn.Module, our_model: nn.Module) -> nn.Module:
+    """Manually copy the layer-scale `nn.Parameter`s that forward hooks cannot trace.
+
+    Maps the original `block{stage}.{layer}.layer_scale_{1,2}` parameters onto
+    the HF `attention_scaling.weight` / `mlp_scaling.weight` keys and reloads
+    the resulting state dict into `our_model`, which is returned.
+    """
+    # nn.Parameter cannot be tracked by the Tracker, thus we need to manually convert them
+    from_state_dict = from_model.state_dict()
+    our_state_dict = our_model.state_dict()
+    config = our_model.config
+    all_keys = []
+    for stage_idx in range(len(config.hidden_sizes)):
+        for block_id in range(config.depths[stage_idx]):
+            # Original checkpoints name stages 1-indexed (block1, block2, ...).
+            from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_1"
+            to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.attention_scaling.weight"
+
+            all_keys.append((from_key, to_key))
+            from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_2"
+            to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.mlp_scaling.weight"
+
+            all_keys.append((from_key, to_key))
+
+    # `pop` doubles as a check that every expected key exists in the source.
+    for from_key, to_key in all_keys:
+        our_state_dict[to_key] = from_state_dict.pop(from_key)
+
+    our_model.load_state_dict(our_state_dict)
+    return our_model
+
+
+def convert_weight_and_push(
+ name: str,
+ config: VanConfig,
+ checkpoint: str,
+ from_model: nn.Module,
+ save_directory: Path,
+ push_to_hub: bool = True,
+):
+ print(f"Downloading weights for {name}...")
+ checkpoint_path = cached_download(checkpoint)
+ print(f"Converting {name}...")
+ from_state_dict = torch.load(checkpoint_path, weights_only=True)["state_dict"]
+ from_model.load_state_dict(from_state_dict)
+ from_model.eval()
+ with torch.no_grad():
+ our_model = VanForImageClassification(config).eval()
+ module_transfer = ModuleTransfer(src=from_model, dest=our_model)
+ x = torch.randn((1, 3, 224, 224))
+ module_transfer(x)
+ our_model = copy_parameters(from_model, our_model)
+
+ if not torch.allclose(from_model(x), our_model(x).logits):
+ raise ValueError("The model logits don't match the original one.")
+
+ checkpoint_name = name
+ print(checkpoint_name)
+
+ if push_to_hub:
+ our_model.push_to_hub(
+ repo_path_or_name=save_directory / checkpoint_name,
+ commit_message="Add model",
+ use_temp_dir=True,
+ )
+
+ # we can use the convnext one
+ image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
+ image_processor.push_to_hub(
+ repo_path_or_name=save_directory / checkpoint_name,
+ commit_message="Add image processor",
+ use_temp_dir=True,
+ )
+
+ print(f"Pushed {checkpoint_name}")
+
+
+def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True):
+ filename = "imagenet-1k-id2label.json"
+ num_labels = 1000
+
+ repo_id = "huggingface/label-files"
+ num_labels = num_labels
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+ id2label = {int(k): v for k, v in id2label.items()}
+
+ id2label = id2label
+ label2id = {v: k for k, v in id2label.items()}
+
+ ImageNetPreTrainedConfig = partial(VanConfig, num_labels=num_labels, id2label=id2label, label2id=label2id)
+
+ names_to_config = {
+ "van-tiny": ImageNetPreTrainedConfig(
+ hidden_sizes=[32, 64, 160, 256],
+ depths=[3, 3, 5, 2],
+ mlp_ratios=[8, 8, 4, 4],
+ ),
+ "van-small": ImageNetPreTrainedConfig(
+ hidden_sizes=[64, 128, 320, 512],
+ depths=[2, 2, 4, 2],
+ mlp_ratios=[8, 8, 4, 4],
+ ),
+ "van-base": ImageNetPreTrainedConfig(
+ hidden_sizes=[64, 128, 320, 512],
+ depths=[3, 3, 12, 3],
+ mlp_ratios=[8, 8, 4, 4],
+ ),
+ "van-large": ImageNetPreTrainedConfig(
+ hidden_sizes=[64, 128, 320, 512],
+ depths=[3, 5, 27, 3],
+ mlp_ratios=[8, 8, 4, 4],
+ ),
+ }
+
+ names_to_original_models = {
+ "van-tiny": van_tiny,
+ "van-small": van_small,
+ "van-base": van_base,
+ "van-large": van_large,
+ }
+
+ names_to_original_checkpoints = {
+ "van-tiny": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar"
+ ),
+ "van-small": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar"
+ ),
+ "van-base": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar"
+ ),
+ "van-large": (
+ "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar"
+ ),
+ }
+
+ if model_name:
+ convert_weight_and_push(
+ model_name,
+ names_to_config[model_name],
+ checkpoint=names_to_original_checkpoints[model_name],
+ from_model=names_to_original_models[model_name](),
+ save_directory=save_directory,
+ push_to_hub=push_to_hub,
+ )
+ else:
+ for model_name, config in names_to_config.items():
+ convert_weight_and_push(
+ model_name,
+ config,
+ checkpoint=names_to_original_checkpoints[model_name],
+ from_model=names_to_original_models[model_name](),
+ save_directory=save_directory,
+ push_to_hub=push_to_hub,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model-name",
+ default=None,
+ type=str,
+ help=(
+ "The name of the model you wish to convert, it must be one of the supported resnet* architecture,"
+ " currently: van-tiny/small/base/large. If `None`, all of them will the converted."
+ ),
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path",
+ default=None,
+ type=Path,
+ required=True,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--van_dir",
+ required=True,
+ type=Path,
+ help=(
+ "A path to VAN's original implementation directory. You can download from here:"
+ " https://github.com/Visual-Attention-Network/VAN-Classification"
+ ),
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ default=True,
+ type=bool,
+ required=False,
+ help="If True, push model and image processor to the hub.",
+ )
+
+ args = parser.parse_args()
+ pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path
+ pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True)
+ van_dir = args.van_dir
+ # append the path to the parents to maskformer dir
+ sys.path.append(str(van_dir.parent))
+ from van.models.van import van_base, van_large, van_small, van_tiny
+
+ convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py b/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py
new file mode 100644
index 0000000000000000000000000000000000000000..1da03cb544d467d9fdbe9b5258fabaccdedd7eff
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py
@@ -0,0 +1,541 @@
+# coding=utf-8
+# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Visual Attention Network (VAN) model."""
+
+import math
+from collections import OrderedDict
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ....modeling_utils import PreTrainedModel
+from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_van import VanConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "VanConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "Visual-Attention-Network/van-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "Visual-Attention-Network/van-base"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+ """
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+ argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return input
+ keep_prob = 1 - drop_prob
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+ random_tensor.floor_() # binarize
+ output = input.div(keep_prob) * random_tensor
+ return output
+
+
+class VanDropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
+ super().__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ return drop_path(hidden_states, self.drop_prob, self.training)
+
+ def extra_repr(self) -> str:
+ return "p={}".format(self.drop_prob)
+
+
+class VanOverlappingPatchEmbedder(nn.Module):
+    """
+    Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
+    half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
+    Transformer](https://arxiv.org/abs/2106.13797).
+    """
+
+    def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4):
+        super().__init__()
+        # `padding=patch_size // 2` keeps windows centered, so the spatial
+        # output size is governed by the stride alone.
+        self.convolution = nn.Conv2d(
+            in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
+        )
+        self.normalization = nn.BatchNorm2d(hidden_size)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Strided conv downsamples; BatchNorm normalizes the embedded patches.
+        hidden_state = self.convolution(input)
+        hidden_state = self.normalization(hidden_state)
+        return hidden_state
+
+
+class VanMlpLayer(nn.Module):
+    """
+    MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
+    Transformer](https://arxiv.org/abs/2106.13797).
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_size: int,
+        out_channels: int,
+        hidden_act: str = "gelu",
+        dropout_rate: float = 0.5,
+    ):
+        super().__init__()
+        # 1x1 convs act as per-pixel dense layers; the 3x3 depth-wise conv in
+        # between mixes local spatial context channel by channel.
+        self.in_dense = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
+        self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
+        self.activation = ACT2FN[hidden_act]
+        self.dropout1 = nn.Dropout(dropout_rate)
+        self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
+        self.dropout2 = nn.Dropout(dropout_rate)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        # expand -> depth-wise mix -> activation -> dropout -> project -> dropout
+        hidden_state = self.in_dense(hidden_state)
+        hidden_state = self.depth_wise(hidden_state)
+        hidden_state = self.activation(hidden_state)
+        hidden_state = self.dropout1(hidden_state)
+        hidden_state = self.out_dense(hidden_state)
+        hidden_state = self.dropout2(hidden_state)
+        return hidden_state
+
+
+class VanLargeKernelAttention(nn.Module):
+    """
+    Basic Large Kernel Attention (LKA).
+    """
+
+    def __init__(self, hidden_size: int):
+        super().__init__()
+        # A large-kernel convolution decomposed into three cheaper pieces: a
+        # 5x5 depth-wise conv, a 7x7 depth-wise conv with dilation 3, and a 1x1
+        # point-wise conv — yielding a large effective receptive field at a
+        # fraction of the parameters.
+        self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size)
+        self.depth_wise_dilated = nn.Conv2d(
+            hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size
+        )
+        self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.depth_wise(hidden_state)
+        hidden_state = self.depth_wise_dilated(hidden_state)
+        hidden_state = self.point_wise(hidden_state)
+        return hidden_state
+
+
+class VanLargeKernelAttentionLayer(nn.Module):
+ """
+ Computes attention using Large Kernel Attention (LKA) and attends the input.
+ """
+
+ def __init__(self, hidden_size: int):
+ super().__init__()
+ self.attention = VanLargeKernelAttention(hidden_size)
+
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+ attention = self.attention(hidden_state)
+ attended = hidden_state * attention
+ return attended
+
+
+class VanSpatialAttentionLayer(nn.Module):
+    """
+    Van spatial attention layer composed by projection (via conv) -> act -> Large Kernel Attention (LKA) attention ->
+    projection (via conv) + residual connection.
+    """
+
+    def __init__(self, hidden_size: int, hidden_act: str = "gelu"):
+        super().__init__()
+        # OrderedDict keeps stable submodule names ("conv", "act") for the state_dict.
+        self.pre_projection = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv", nn.Conv2d(hidden_size, hidden_size, kernel_size=1)),
+                    ("act", ACT2FN[hidden_act]),
+                ]
+            )
+        )
+        self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
+        self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        # Residual connection around the whole projected-attention sandwich.
+        residual = hidden_state
+        hidden_state = self.pre_projection(hidden_state)
+        hidden_state = self.attention_layer(hidden_state)
+        hidden_state = self.post_projection(hidden_state)
+        hidden_state = hidden_state + residual
+        return hidden_state
+
+
+class VanLayerScaling(nn.Module):
+    """
+    Scales the inputs by a learnable parameter initialized by `initial_value`.
+    """
+
+    def __init__(self, hidden_size: int, initial_value: float = 1e-2):
+        super().__init__()
+        # One learnable scale per channel, initialized small so each block
+        # starts close to an identity mapping.
+        self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        # unsqueezing for broadcasting
+        # (C,) -> (C, 1, 1) so the per-channel scale broadcasts over H and W.
+        hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
+        return hidden_state
+
+
+class VanLayer(nn.Module):
+    """
+    Van layer composed by normalization layers, large kernel attention (LKA) and a multi layer perceptron (MLP).
+    """
+
+    def __init__(
+        self,
+        config: VanConfig,
+        hidden_size: int,
+        mlp_ratio: int = 4,
+        drop_path_rate: float = 0.5,
+    ):
+        super().__init__()
+        self.drop_path = VanDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        # NOTE: "normomalization" is a typo, but the attribute name is a
+        # pretrained state_dict key — renaming it would break checkpoint loading.
+        self.pre_normomalization = nn.BatchNorm2d(hidden_size)
+        self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act)
+        self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
+        self.post_normalization = nn.BatchNorm2d(hidden_size)
+        self.mlp = VanMlpLayer(
+            hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate
+        )
+        self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        # Two pre-norm residual sub-blocks: (norm -> LKA -> scale -> drop-path)
+        # then (norm -> MLP -> scale -> drop-path).
+        residual = hidden_state
+        # attention
+        hidden_state = self.pre_normomalization(hidden_state)
+        hidden_state = self.attention(hidden_state)
+        hidden_state = self.attention_scaling(hidden_state)
+        hidden_state = self.drop_path(hidden_state)
+        # residual connection
+        hidden_state = residual + hidden_state
+        residual = hidden_state
+        # mlp
+        hidden_state = self.post_normalization(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = self.mlp_scaling(hidden_state)
+        hidden_state = self.drop_path(hidden_state)
+        # residual connection
+        hidden_state = residual + hidden_state
+        return hidden_state
+
+
+class VanStage(nn.Module):
+    """
+    VanStage, consisting of multiple layers.
+    """
+
+    def __init__(
+        self,
+        config: VanConfig,
+        in_channels: int,
+        hidden_size: int,
+        patch_size: int,
+        stride: int,
+        depth: int,
+        mlp_ratio: int = 4,
+        drop_path_rate: float = 0.0,
+    ):
+        super().__init__()
+        # Each stage downsamples once, then runs `depth` identical layers;
+        # note every layer in the stage shares the same drop_path_rate.
+        self.embeddings = VanOverlappingPatchEmbedder(in_channels, hidden_size, patch_size, stride)
+        self.layers = nn.Sequential(
+            *[
+                VanLayer(
+                    config,
+                    hidden_size,
+                    mlp_ratio=mlp_ratio,
+                    drop_path_rate=drop_path_rate,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.embeddings(hidden_state)
+        hidden_state = self.layers(hidden_state)
+        # rearrange b c h w -> b (h w) c
+        # LayerNorm normalizes over the channel dimension, so flatten the
+        # spatial dims and move channels last before applying it.
+        batch_size, hidden_size, height, width = hidden_state.shape
+        hidden_state = hidden_state.flatten(2).transpose(1, 2)
+        hidden_state = self.normalization(hidden_state)
+        # rearrange b (h w) c- > b c h w
+        hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
+        return hidden_state
+
+
+class VanEncoder(nn.Module):
+    """
+    VanEncoder, consisting of multiple stages.
+    """
+
+    def __init__(self, config: VanConfig):
+        super().__init__()
+        self.stages = nn.ModuleList([])
+        patch_sizes = config.patch_sizes
+        strides = config.strides
+        hidden_sizes = config.hidden_sizes
+        depths = config.depths
+        mlp_ratios = config.mlp_ratios
+        # NOTE(review): linspace produces sum(depths) rates, but the zip below
+        # truncates to the number of stages, so only the first len(stages)
+        # values are used and all layers in a stage share one rate — TODO
+        # confirm this matches the original VAN implementation.
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
+        ]
+
+        for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
+            zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)
+        ):
+            # The first stage consumes the raw image channels; later stages
+            # consume the previous stage's hidden size.
+            is_first_stage = num_stage == 0
+            in_channels = hidden_sizes[num_stage - 1]
+            if is_first_stage:
+                in_channels = config.num_channels
+            self.stages.append(
+                VanStage(
+                    config,
+                    in_channels,
+                    hidden_size,
+                    patch_size=patch_size,
+                    stride=stride,
+                    depth=depth,
+                    mlp_ratio=mlp_expantion,
+                    drop_path_rate=drop_path_rate,
+                )
+            )
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
+        all_hidden_states = () if output_hidden_states else None
+
+        for _, stage_module in enumerate(self.stages):
+            hidden_state = stage_module(hidden_state)
+
+            # Collect the output of every stage when requested (the input image
+            # itself is not included).
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_state,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
+
+        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
+
+
+class VanPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = VanConfig
+    base_model_prefix = "van"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Truncated normal for dense layers, zero bias.
+            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                nn.init.constant_(module.bias, 0)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.constant_(module.bias, 0)
+            nn.init.constant_(module.weight, 1.0)
+        elif isinstance(module, nn.Conv2d):
+            # Kaiming-style fan-out init; grouped (depth-wise) convs divide the
+            # fan-out by the number of groups.
+            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+            fan_out //= module.groups
+            module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+
+# Shared docstring fragments consumed by the `add_start_docstrings*` decorators below.
+VAN_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`VanConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VAN_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`ConvNextImageProcessor.__call__`] for details.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all stages. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding"
+ " layer.",
+ VAN_START_DOCSTRING,
+)
+class VanModel(VanPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.encoder = VanEncoder(config)
+        # final layernorm layer
+        # NOTE(review): this layernorm is defined but never applied in
+        # `forward` below — presumably kept for checkpoint compatibility; TODO
+        # confirm against the original implementation.
+        self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPoolingAndNoAttention,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
+        # Fall back to config-level flags when the caller leaves them unset.
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        encoder_outputs = self.encoder(
+            pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        last_hidden_state = encoder_outputs[0]
+        # global average pooling, n c w h -> n c
+        pooled_output = last_hidden_state.mean(dim=[-2, -1])
+
+        if not return_dict:
+            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndNoAttention(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+        )
+
+
+@add_start_docstrings(
+ """
+ VAN Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
+ ImageNet.
+ """,
+ VAN_START_DOCSTRING,
+)
+class VanForImageClassification(VanPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.van = VanModel(config)
+        # Classifier head
+        # Linear head over the pooled features; identity when num_labels == 0.
+        self.classifier = (
+            nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutputWithNoAttention,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.van(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
+
+        pooled_output = outputs.pooler_output if return_dict else outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            # Infer the problem type from num_labels and label dtype on first
+            # use; note this mutates self.config in place so later calls reuse it.
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
+
+
+__all__ = ["VanForImageClassification", "VanModel", "VanPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5bd93aa4dabea9b2adbc54eeeb3664de589f43a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers see the real symbols; at runtime imports are lazy.
+    from .configuration_vit_hybrid import *
+    from .image_processing_vit_hybrid import *
+    from .modeling_vit_hybrid import *
+else:
+    import sys
+
+    # Replace this module object with a lazy proxy that imports the submodules
+    # only on first attribute access.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b6a3e5ef515218afb72e8b871c4ffb2465666f
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ViT Hybrid model configuration"""
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+from ...auto.configuration_auto import CONFIG_MAPPING
+from ...bit import BitConfig
+
+
+logger = logging.get_logger(__name__)
+
+
class ViTHybridConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTHybridModel`]. It is used to instantiate a ViT
    Hybrid model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the ViT Hybrid
    [google/vit-hybrid-base-bit-384](https://huggingface.co/google/vit-hybrid-base-bit-384) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
            The configuration of the backbone in a dictionary or the config object of the backbone.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 1):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`):
            Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.

    Example:

    ```python
    >>> from transformers import ViTHybridConfig, ViTHybridModel

    >>> # Initializing a ViT Hybrid vit-hybrid-base-bit-384 style configuration
    >>> configuration = ViTHybridConfig()

    >>> # Initializing a model (with random weights) from the vit-hybrid-base-bit-384 style configuration
    >>> model = ViTHybridModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vit-hybrid"

    def __init__(
        self,
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        backbone_kwargs=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=1,
        num_channels=3,
        backbone_featmap_shape=None,
        qkv_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if use_pretrained_backbone:
            raise ValueError("Pretrained backbones are not supported yet.")

        # `backbone` (a checkpoint/model name) and `backbone_config` (an explicit
        # config) are mutually exclusive ways of specifying the conv backbone.
        if backbone_config is not None and backbone is not None:
            raise ValueError("You can't specify both `backbone` and `backbone_config`.")

        if backbone_config is None and backbone is None:
            logger.info("`backbone_config` is `None`. Initializing the config with a `BiT` backbone.")
            backbone_config = {
                "global_padding": "same",
                "layer_type": "bottleneck",
                "depths": [3, 4, 9],
                "out_features": ["stage3"],
                "embedding_dynamic_padding": True,
            }

        # A non-empty `backbone_kwargs` only makes sense when the backbone is
        # loaded by name, never together with a fully specified config.
        if backbone_kwargs and backbone_config is not None:
            raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

        if isinstance(backbone_config, dict):
            if "model_type" in backbone_config:
                backbone_config_class = CONFIG_MAPPING[backbone_config["model_type"]]
            else:
                logger.info(
                    "`model_type` is not found in `backbone_config`. Use `Bit` as the backbone configuration class."
                )
                backbone_config_class = BitConfig
            backbone_config = backbone_config_class(**backbone_config)

        # Default applied here instead of in the signature to avoid a shared
        # mutable default list across config instances.
        self.backbone_featmap_shape = [1, 1024, 24, 24] if backbone_featmap_shape is None else backbone_featmap_shape
        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.backbone_kwargs = backbone_kwargs
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias


__all__ = ["ViTHybridConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d717d74c961e509697adab7623d2bc3fe64a1cf
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py
@@ -0,0 +1,282 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ViT hybrid checkpoints from the timm library."""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import timm
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
+
+from transformers import (
+ BitConfig,
+ ViTHybridConfig,
+ ViTHybridForImageClassification,
+ ViTHybridImageProcessor,
+ ViTHybridModel,
+)
+from transformers.image_utils import PILImageResampling
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    """Build the list of (timm_name, hf_name) parameter renames for ViT Hybrid.

    Covers the ViT stem, the BiT backbone (stem + stages), the transformer
    encoder layers and either the pooler (`base_model=True`, with the "vit."
    prefix stripped) or the classification head (`base_model=False`).
    """
    keys = []

    # ViT stem: CLS token, position embeddings, patch projection.
    keys += [
        ("cls_token", "vit.embeddings.cls_token"),
        ("pos_embed", "vit.embeddings.position_embeddings"),
        ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
        ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
    ]

    # BiT backbone stem.
    old_stem = "patch_embed.backbone.stem"
    new_stem = "vit.embeddings.patch_embeddings.backbone.bit.embedder"
    keys += [
        (f"{old_stem}.conv.weight", f"{new_stem}.convolution.weight"),
        (f"{old_stem}.norm.weight", f"{new_stem}.norm.weight"),
        (f"{old_stem}.norm.bias", f"{new_stem}.norm.bias"),
    ]

    # BiT backbone stages: every bottleneck block's convs/norms, then the
    # stage's downsample (attached to block 0 in timm).
    block_parts = (
        "conv1.weight", "norm1.weight", "norm1.bias",
        "conv2.weight", "norm2.weight", "norm2.bias",
        "conv3.weight", "norm3.weight", "norm3.bias",
    )
    for stage_idx, depth in enumerate(config.backbone_config.depths):
        old_stage = f"patch_embed.backbone.stages.{stage_idx}"
        new_stage = f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}"
        for layer_idx in range(depth):
            for part in block_parts:
                keys.append((f"{old_stage}.blocks.{layer_idx}.{part}", f"{new_stage}.layers.{layer_idx}.{part}"))
        for part in ("downsample.conv.weight", "downsample.norm.weight", "downsample.norm.bias"):
            keys.append((f"{old_stage}.blocks.0.{part}", f"{new_stage}.layers.0.{part}"))

    # Transformer encoder: output projection, 2 feedforward layers, 2 layernorms.
    layer_parts = (
        ("norm1.weight", "layernorm_before.weight"),
        ("norm1.bias", "layernorm_before.bias"),
        ("attn.proj.weight", "attention.output.dense.weight"),
        ("attn.proj.bias", "attention.output.dense.bias"),
        ("norm2.weight", "layernorm_after.weight"),
        ("norm2.bias", "layernorm_after.bias"),
        ("mlp.fc1.weight", "intermediate.dense.weight"),
        ("mlp.fc1.bias", "intermediate.dense.bias"),
        ("mlp.fc2.weight", "output.dense.weight"),
        ("mlp.fc2.bias", "output.dense.bias"),
    )
    for i in range(config.num_hidden_layers):
        keys += [(f"blocks.{i}.{old}", f"vit.encoder.layer.{i}.{new}") for old, new in layer_parts]

    if base_model:
        # Final layernorm + pooler; then drop the leading "vit." everywhere.
        keys += [
            ("norm.weight", "layernorm.weight"),
            ("norm.bias", "layernorm.bias"),
            ("pre_logits.fc.weight", "pooler.dense.weight"),
            ("pre_logits.fc.bias", "pooler.dense.bias"),
        ]
        keys = [(old, new[4:]) if new.startswith("vit") else (old, new) for old, new in keys]
    else:
        # Final layernorm + classification head.
        keys += [
            ("norm.weight", "vit.layernorm.weight"),
            ("norm.bias", "vit.layernorm.bias"),
            ("head.weight", "classifier.weight"),
            ("head.bias", "classifier.bias"),
        ]

    return keys
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    """Split each timm fused qkv projection into separate q/k/v entries.

    timm stores the attention input projection as one (3 * hidden, hidden)
    weight and one (3 * hidden,) bias. The HF model expects distinct
    query/key/value parameters, so `state_dict` is rewritten in place.
    """
    # Without the classification head the checkpoint has no "vit." prefix.
    prefix = "" if base_model else "vit."
    hidden = config.hidden_size
    for i in range(config.num_hidden_layers):
        # Pop the fused matrix + bias of this layer's input projection.
        fused_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        fused_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # Re-insert as query, key and value (stacked in that order in timm).
        attn = f"{prefix}encoder.layer.{i}.attention.attention"
        state_dict[f"{attn}.query.weight"] = fused_weight[:hidden, :]
        state_dict[f"{attn}.query.bias"] = fused_bias[:hidden]
        state_dict[f"{attn}.key.weight"] = fused_weight[hidden : hidden * 2, :]
        state_dict[f"{attn}.key.bias"] = fused_bias[hidden : hidden * 2]
        state_dict[f"{attn}.value.weight"] = fused_weight[-hidden:, :]
        state_dict[f"{attn}.value.bias"] = fused_bias[-hidden:]
+
+
def remove_classification_head_(state_dict):
    """Drop the timm classification head weights from `state_dict` in place.

    Missing keys are ignored, so this is safe on headless checkpoints.
    """
    for key in ("head.weight", "head.bias"):
        state_dict.pop(key, None)
+
+
def rename_key(dct, old, new):
    """Move the value stored under key `old` to key `new`, in place."""
    dct[new] = dct.pop(old)
+
+
+# We will verify our results on an image of cute cats
def prepare_img():
    """Download and return the standard COCO cats image used for verification."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    return Image.open(requests.get(url, stream=True).raw)
+
+
@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our ViT structure.

    Downloads the timm checkpoint named `vit_name`, renames/splits its
    parameters into the HF ViT-Hybrid layout, verifies the image processor
    and logits against the timm model on a test image, and optionally saves
    to `pytorch_dump_folder_path` and/or pushes to the hub.
    """

    # define default ViT hybrid configuration
    # NOTE(review): these BiT settings match the R50 stem of
    # "vit_base_r50_s16_384"; other hybrid variants may need different values.
    backbone_config = BitConfig(
        global_padding="same",
        layer_type="bottleneck",
        depths=(3, 4, 9),
        out_features=["stage3"],
        embedding_dynamic_padding=True,
    )
    config = ViTHybridConfig(backbone_config=backbone_config, image_size=384, num_labels=1000)
    # Hardcoded: only the classification-head variant is converted here.
    base_model = False

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    # Split the fused qkv projections into separate query/key/value entries.
    read_in_q_k_v(state_dict, config, base_model)

    # Attach ImageNet-1k label mappings fetched from the hub.
    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}

    # load HuggingFace model
    # NOTE(review): "in21k" checkpoints have no classification head, but
    # `base_model` stays False above — confirm intent before converting one.
    if vit_name[-5:] == "in21k":
        model = ViTHybridModel(config).eval()
    else:
        model = ViTHybridForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # create image processor
    # Reuse timm's inferred preprocessing (resize/crop/normalize) settings.
    transform = create_transform(**resolve_data_config({}, model=timm_model))
    timm_transforms = transform.transforms

    pillow_resamplings = {
        "bilinear": PILImageResampling.BILINEAR,
        "bicubic": PILImageResampling.BICUBIC,
        "nearest": PILImageResampling.NEAREST,
    }

    processor = ViTHybridImageProcessor(
        do_resize=True,
        size={"shortest_edge": timm_transforms[0].size},
        resample=pillow_resamplings[timm_transforms[0].interpolation.value],
        do_center_crop=True,
        crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]},
        do_normalize=True,
        image_mean=timm_transforms[-1].mean.tolist(),
        image_std=timm_transforms[-1].std.tolist(),
    )

    image = prepare_img()
    timm_pixel_values = transform(image).unsqueeze(0)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # verify pixel values
    assert torch.allclose(timm_pixel_values, pixel_values)

    # verify logits
    with torch.no_grad():
        outputs = model(pixel_values)
    logits = outputs.logits

    print("Predicted class:", logits.argmax(-1).item())
    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.pooler_output.shape
        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
    else:
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and processor to the hub {vit_name}")
        model.push_to_hub(f"ybelkada/{vit_name}")
        processor.push_to_hub(f"ybelkada/{vit_name}")
+
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--vit_name",
        default="vit_base_r50_s16_384",
        type=str,
        help="Name of the hybrid ViT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether to upload the model to the HuggingFace hub."
    )

    args = parser.parse_args()
    convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..72410878933d7bd9726794588dc241c08b405a25
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for ViT hybrid."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ....image_transforms import (
+ convert_to_rgb,
+ get_resize_output_image_size,
+ resize,
+ to_channel_dimension_format,
+)
+from ....image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ....utils import TensorType, is_vision_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+
+
class ViTHybridImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ViT Hybrid image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
            Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
            the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
            method.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_center_crop (`bool`, *optional*, defaults to `True`):
            Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
            `preprocess` method.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
            Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
            method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Normalize `size`/`crop_size` into canonical dicts; `size` keeps the
        # shortest-edge form (non-square), `crop_size` is always height/width.
        size = size if size is not None else {"shortest_edge": 224}
        size = get_size_dict(size, default_to_square=False)
        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
        crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        # Defaults match CLIP-style normalization, not the ImageNet standard.
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.do_convert_rgb = do_convert_rgb
        # Keyword names accepted by `preprocess`; used to reject unknown kwargs.
        self._valid_processor_keys = [
            "images",
            "do_resize",
            "size",
            "resample",
            "do_center_crop",
            "crop_size",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "image_mean",
            "image_std",
            "do_convert_rgb",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
        resized to keep the input aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        # "shortest_edge" → aspect-preserving resize; height/width → exact size.
        default_to_square = True
        if "shortest_edge" in size:
            size = size["shortest_edge"]
            default_to_square = False
        elif "height" in size and "width" in size:
            size = (size["height"], size["width"])
        else:
            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")

        output_size = get_resize_output_image_size(
            image,
            size=size,
            default_to_square=default_to_square,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: defaults to the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        # Fall back to the instance-level settings for any unspecified argument.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size, param_name="size", default_to_square=False)
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb

        images = make_list_of_images(images)

        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # PIL RGBA images are converted to RGB
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        # Warn (once) if inputs already look rescaled to [0, 1].
        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # Apply the pipeline in a fixed order: resize → crop → rescale → normalize.
        all_images = []
        for image in images:
            if do_resize:
                image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

            if do_center_crop:
                image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)

            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            all_images.append(image)
        # Convert to the requested output channel layout last.
        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            for image in all_images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["ViTHybridImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d2511019395f97ad790f9ab548667a4cac3b5b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py
@@ -0,0 +1,770 @@
+# coding=utf-8
+# Copyright 2022 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ViT Hybrid model."""
+
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ....activations import ACT2FN
+from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ....modeling_utils import PreTrainedModel
+from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ....utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ torch_int,
+)
+from ....utils.backbone_utils import load_backbone
+from .configuration_vit_hybrid import ViTHybridConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "ViTHybridConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "google/vit-hybrid-base-bit-384"
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/vit-hybrid-base-bit-384"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+class ViTHybridEmbeddings(nn.Module):
+    """
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None:
+        super().__init__()
+
+        # Learnable [CLS] token, prepended to every patch sequence.
+        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+        # The mask token is only allocated when masked-image modeling is requested.
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+        self.patch_embeddings = ViTHybridPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        # One learned position embedding per patch, plus one for the [CLS] token.
+        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.patch_size = config.patch_size
+        self.config = config
+
+    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+        images. This method is also adapted to support torch.jit tracing.
+
+        Adapted from:
+        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+        """
+
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+
+        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+            return self.position_embeddings
+
+        # Split off the [CLS] position embedding; only the patch positions are interpolated.
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+
+        dim = embeddings.shape[-1]
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        # Pre-trained patch positions are assumed to form a square grid.
+        sqrt_num_positions = torch_int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        # Bicubic interpolation over the (height, width) grid of position embeddings.
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            size=(new_height, new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> torch.Tensor:
+        """
+        Embed `pixel_values` of shape `(batch_size, num_channels, height, width)` into a token
+        sequence of shape `(batch_size, num_patches + 1, hidden_size)` including the [CLS] token.
+        `bool_masked_pos` marks patches whose embeddings are replaced by the mask token
+        (requires the module to have been built with `use_mask_token=True`).
+        """
+        batch_size, num_channels, height, width = pixel_values.shape
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        if bool_masked_pos is not None:
+            seq_length = embeddings.shape[1]
+            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        # add the [CLS] token to the embedded patch tokens
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
+
+        # add positional encoding to each token
+        if interpolate_pos_encoding:
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            embeddings = embeddings + self.position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings
+
+
+class ViTHybridPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config, feature_size=None):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+
+ self.backbone = load_backbone(config)
+ if self.backbone.config.model_type != "bit":
+ raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.")
+ feature_dim = self.backbone.channels[-1]
+
+ if feature_size is None:
+ feature_map = config.backbone_featmap_shape
+
+ feature_size = feature_map[-2:]
+ feature_dim = feature_map[1]
+ else:
+ feature_size = (
+ feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
+ )
+ feature_dim = self.backbone.channels[-1]
+
+ self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+
+ self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ _, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ features = self.backbone(pixel_values).feature_maps[-1]
+ embeddings = self.projection(features).flatten(2).transpose(1, 2)
+
+ return embeddings
+
+
+class ViTHybridSelfAttention(nn.Module):
+ def __init__(self, config: ViTHybridConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention):
+    """Self-attention variant backed by `torch.nn.functional.scaled_dot_product_attention`."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__(config)
+        # Kept so the dropout probability can be passed to SDPA (applied only in training mode).
+        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+    def forward(
+        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        # NOTE(review): `output_attentions` cannot be honored here — SDPA does not expose the
+        # attention probabilities, so the second tuple element is always None.
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # NOTE(review): `head_mask` is forwarded as SDPA's `attn_mask`, which is *added* to the
+        # logits (or used as a boolean keep-mask), whereas the eager path *multiplies* it into
+        # the probabilities — confirm the intended mask semantics before relying on head masking
+        # with the SDPA implementation.
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            head_mask,
+            self.attention_probs_dropout_prob if self.training else 0.0,
+            is_causal=False,
+            scale=None,
+        )
+
+        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        return context_layer, None
+
+
+class ViTHybridSelfOutput(nn.Module):
+ """
+ The residual connection is defined in ViTHybridLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: ViTHybridConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class ViTHybridAttention(nn.Module):
+    """Self-attention plus its output projection, with support for attention-head pruning."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__()
+        self.attention = ViTHybridSelfAttention(config)
+        self.output = ViTHybridSelfOutput(config)
+        # Head indices pruned so far; used to avoid re-pruning the same heads.
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads: Set[int]) -> None:
+        """Remove the given attention heads from the query/key/value/output projections in place."""
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.attention.query = prune_linear_layer(self.attention.query, index)
+        self.attention.key = prune_linear_layer(self.attention.key, index)
+        self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+        # Project the context back to hidden_size (residual add happens in ViTHybridLayer).
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+        return outputs
+
+
+class ViTHybridSdpaAttention(ViTHybridAttention):
+    """ViTHybridAttention that swaps in the SDPA-based self-attention implementation."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__(config)
+        # Replace the eager self-attention created by the parent constructor.
+        self.attention = ViTHybridSdpaSelfAttention(config)
+
+
+class ViTHybridIntermediate(nn.Module):
+ def __init__(self, config: ViTHybridConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class ViTHybridOutput(nn.Module):
+ def __init__(self, config: ViTHybridConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+# Maps `config._attn_implementation` to the attention module class to instantiate.
+VIT_HYBRID_ATTENTION_CLASSES = {
+    "eager": ViTHybridAttention,
+    "sdpa": ViTHybridSdpaAttention,
+}
+
+
+class ViTHybridLayer(nn.Module):
+    """This corresponds to the Block class in the timm implementation."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__()
+        # Stored for API parity with other models; not used by `forward` in this layer.
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        # Attention implementation ("eager" or "sdpa") selected via the config.
+        self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.intermediate = ViTHybridIntermediate(config)
+        self.output = ViTHybridOutput(config)
+        # Pre-norm architecture: layernorm before attention and before the MLP.
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+        """Run one pre-norm transformer block; returns `(hidden_states, [attention_probs])`."""
+        self_attention_outputs = self.attention(
+            self.layernorm_before(hidden_states),  # in ViTHybrid, layernorm is applied before self-attention
+            head_mask,
+            output_attentions=output_attentions,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        # first residual connection
+        # We assign to correct device for `accelerate`, check: https://github.com/huggingface/transformers/pull/20705/
+        hidden_states = attention_output + hidden_states.to(attention_output.device)
+
+        # in ViTHybrid, layernorm is also applied after self-attention
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = self.intermediate(layer_output)
+
+        # second residual connection is done here
+        layer_output = self.output(layer_output, hidden_states)
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+class ViTHybridEncoder(nn.Module):
+    """Stack of `config.num_hidden_layers` ViTHybridLayer blocks."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([ViTHybridLayer(config) for _ in range(config.num_hidden_layers)])
+        # Toggled externally (PreTrainedModel API) to trade compute for memory during training.
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[tuple, BaseModelOutput]:
+        """
+        Run all transformer layers. Optionally collects per-layer hidden states (including the
+        input embeddings) and attention probabilities, and returns either a tuple or a
+        `BaseModelOutput` depending on `return_dict`.
+        """
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                # Recompute activations in the backward pass instead of storing them.
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    layer_head_mask,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+class ViTHybridPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = ViTHybridConfig
+ base_model_prefix = "vit"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, ViTHybridEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+ module.cls_token.data = nn.init.trunc_normal_(
+ module.cls_token.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.cls_token.dtype)
+ module.mask_token.data.zero_()
+
+
+VIT_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`ViTHybridConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIT_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`ViTHybridImageProcessor.__call__`] for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare ViT Hybrid Model transformer outputting raw hidden-states without any specific head on top.",
+ VIT_START_DOCSTRING,
+)
+class ViTHybridModel(ViTHybridPreTrainedModel):
+    """Bare ViT Hybrid encoder: BiT-backbone patch embeddings + transformer encoder, no task head."""
+
+    def __init__(self, config: ViTHybridConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ViTHybridEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = ViTHybridEncoder(config)
+
+        # Final layernorm applied to the encoder output.
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.pooler = ViTHybridPooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> ViTHybridPatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        # Fall back to config defaults for unspecified output options.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+class ViTHybridPooler(nn.Module):
+ def __init__(self, config: ViTHybridConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+@add_start_docstrings(
+ """
+ ViT Hybrid Model transformer with an image classification head on top (a linear layer on top of the final hidden
+ state of the [CLS] token) e.g. for ImageNet.
+ """,
+ VIT_START_DOCSTRING,
+)
+class ViTHybridForImageClassification(ViTHybridPreTrainedModel):
+    """ViT Hybrid encoder with a linear classification head on the [CLS] token."""
+
+    def __init__(self, config: ViTHybridConfig) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.vit = ViTHybridModel(config, add_pooling_layer=False)
+
+        # Classifier head
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vit(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # Classify from the [CLS] token (position 0) only.
+        logits = self.classifier(sequence_output[:, 0, :])
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Infer the problem type once from label dtype/shape if not set in the config.
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+# Public API of this module.
+__all__ = ["ViTHybridForImageClassification", "ViTHybridModel", "ViTHybridPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13c67012fa157a94bae21734a1f59b26c78588d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ....utils import _LazyModule
+from ....utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # For static type checkers, expose submodule symbols eagerly.
+    from .configuration_xlm_prophetnet import *
+    from .modeling_xlm_prophetnet import *
+    from .tokenization_xlm_prophetnet import *
+else:
+    import sys
+
+    # At runtime, replace this module with a lazy proxy that imports submodules on first
+    # attribute access, keeping top-level package import cheap.
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d7751d9541ec1a3b39799915d742955d863cd24
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""XLM-ProphetNet model configuration"""
+
+from typing import Callable, Optional, Union
+
+from ....configuration_utils import PretrainedConfig
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class XLMProphetNetConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`XLMProphetNetModel`]. It is used to instantiate a
+ XLMProphetNet model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the XLMProphetNet
+ [microsoft/xprophetnet-large-wiki100-cased](https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased)
+ architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ activation_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for activations inside the fully connected layer.
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ vocab_size (`int`, *optional*, defaults to 30522):
+ Vocabulary size of the ProphetNET model. Defines the number of different tokens that can be represented by
+ the `inputs_ids` passed when calling [`XLMProphetNetModel`].
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the layers and the pooler layer.
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
+ num_encoder_layers (`int`, *optional*, defaults to 12):
+ Number of encoder layers.
+ num_encoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+ Dimensionality of the `intermediate` (often named feed-forward) layer in decoder.
+ num_decoder_layers (`int`, *optional*, defaults to 12):
+ Number of decoder layers.
+ num_decoder_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ attention_dropout (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ dropout (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ init_std (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ add_cross_attention (`bool`, *optional*, defaults to `True`):
+ Whether cross-attention layers should be added to the model.
+ is_encoder_decoder (`bool`, *optional*, defaults to `True`):
+ Whether this is an encoder/decoder model.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ ngram (`int`, *optional*, defaults to 2):
+ Number of future tokens to predict. Set to 1 to be same as traditional Language model to predict next first
+ token.
+ num_buckets (`int`, *optional*, defaults to 32):
+ The number of buckets to use for each attention layer. This is for relative position calculation. See the
+ [T5 paper](https://arxiv.org/abs/1910.10683) for more details.
+ relative_max_distance (`int`, *optional*, defaults to 128):
+ Relative distances greater than this number will be put into the last same bucket. This is for relative
+ position calculation. See the [T5 paper](https://arxiv.org/abs/1910.10683) for more details.
+ disable_ngram_loss (`bool`, *optional*, defaults to `False`):
+ Whether the model should be trained to predict only the next first token.
+ eps (`float`, *optional*, defaults to 0.0):
+ Controls the `epsilon` parameter value for label smoothing in the loss calculation. If set to 0, no label
+ smoothing is performed.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models).
+ """
+
+ model_type = "xlm-prophetnet"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {
+ "num_attention_heads": "num_encoder_attention_heads",
+ }
+
+ def __init__(
+ self,
+ activation_dropout: Optional[float] = 0.1,
+ activation_function: Optional[Union[str, Callable]] = "gelu",
+ vocab_size: Optional[int] = 30522,
+ hidden_size: Optional[int] = 1024,
+ encoder_ffn_dim: Optional[int] = 4096,
+ num_encoder_layers: Optional[int] = 12,
+ num_encoder_attention_heads: Optional[int] = 16,
+ decoder_ffn_dim: Optional[int] = 4096,
+ num_decoder_layers: Optional[int] = 12,
+ num_decoder_attention_heads: Optional[int] = 16,
+ attention_dropout: Optional[float] = 0.1,
+ dropout: Optional[float] = 0.1,
+ max_position_embeddings: Optional[int] = 512,
+ init_std: Optional[float] = 0.02,
+ is_encoder_decoder: Optional[bool] = True,
+ add_cross_attention: Optional[bool] = True,
+ decoder_start_token_id: Optional[int] = 0,
+ ngram: Optional[int] = 2,
+ num_buckets: Optional[int] = 32,
+ relative_max_distance: Optional[int] = 128,
+ disable_ngram_loss: Optional[bool] = False,
+ eps: Optional[float] = 0.0,
+ use_cache: Optional[bool] = True,
+ pad_token_id: Optional[int] = 0,
+ bos_token_id: Optional[int] = 1,
+ eos_token_id: Optional[int] = 2,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.num_encoder_layers = num_encoder_layers
+ self.num_encoder_attention_heads = num_encoder_attention_heads
+ self.decoder_ffn_dim = decoder_ffn_dim
+ self.num_decoder_layers = num_decoder_layers
+ self.num_decoder_attention_heads = num_decoder_attention_heads
+ self.max_position_embeddings = max_position_embeddings
+ self.init_std = init_std # Normal(0, this parameter)
+ self.activation_function = activation_function
+
+ # parameters for xlmprophetnet
+ self.ngram = ngram
+ self.num_buckets = num_buckets
+ self.relative_max_distance = relative_max_distance
+ self.disable_ngram_loss = disable_ngram_loss
+ self.eps = eps
+
+ # 3 Types of Dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.dropout = dropout
+
+ self.use_cache = use_cache
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ is_encoder_decoder=is_encoder_decoder,
+ add_cross_attention=add_cross_attention,
+ decoder_start_token_id=decoder_start_token_id,
+ **kwargs,
+ )
+
+ @property
+ def num_hidden_layers(self) -> int:
+ return self.num_encoder_layers + self.num_decoder_layers
+
+ @num_hidden_layers.setter
+ def num_hidden_layers(self, value):
+ raise NotImplementedError(
+ "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and"
+ " `num_decoder_layers`."
+ )
+
+
+__all__ = ["XLMProphetNetConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..17bc9ffada60028e6531c90b3939cf22f846cdf3
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -0,0 +1,2346 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch XLM-ProphetNet model."""
+
+import copy
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import Tensor, nn
+from torch.nn import LayerNorm
+
+from ....activations import ACT2FN
+from ....modeling_outputs import BaseModelOutput
+from ....modeling_utils import PreTrainedModel
+from ....utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_xlm_prophetnet import XLMProphetNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CONFIG_FOR_DOC = "XLMProphetNetConfig"
+
+
+XLM_PROPHETNET_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted
+ from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the
+ file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`.
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`XLMProphetNetConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+XLM_PROPHETNET_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ XLMProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def softmax(hidden_state, dim, onnx_trace=False):
+ if onnx_trace:
+ return nn.functional.softmax(hidden_state.float(), dim=dim)
+ else:
+ return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32)
+
+
+def ngram_attention_bias(sequence_length, ngram, device, dtype):
+ """
+ This function computes the bias for the predict stream
+ """
+ left_block = (
+ torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
+ )
+ right_block = left_block.detach().clone()
+ # create bias
+ for stream_idx in range(ngram):
+ right_block[stream_idx].fill_diagonal_(0, wrap=False)
+ left_block[stream_idx].triu_(-stream_idx + 1)
+
+ left_block[:, :, 0] = 0
+ return torch.cat([left_block, right_block], dim=2)
+
+
+def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
+ """
+ This function computes individual parts of the relative position buckets. For more detail, see paper.
+ """
+ inv_relative_positions = -relative_positions
+ rel_positions_bucket = 0
+
+ if is_bidirectional:
+ num_buckets = num_buckets // 2
+ rel_positions_bucket = (
+ rel_positions_bucket
+ + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets
+ )
+ inv_relative_positions = torch.abs(inv_relative_positions)
+ else:
+ inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = torch.lt(inv_relative_positions, max_exact)
+ val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log(
+ max_distance / max_exact
+ ) * (num_buckets - max_exact)
+ val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int()
+ rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large)
+ return rel_positions_bucket
+
+
+def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids):
+ """
+ This function computes both main and predict relative position buckets. For more detail, see paper.
+ """
+ # main stream
+ main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1)
+ main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1)
+
+ # predicting stream
+ predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1)
+ predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1)
+ predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1)
+
+ # get both position buckets
+ main_relative_position_buckets = compute_relative_buckets(
+ num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False
+ )
+ predict_relative_position_buckets = compute_relative_buckets(
+ num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False
+ )
+ return main_relative_position_buckets, predict_relative_position_buckets
+
+
+@dataclass
+class XLMProphetNetSeq2SeqLMOutput(ModelOutput):
+ """
+ Base class for sequence-to-sequence language models outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss.
+ logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
+ Prediction scores of the main stream language modeling head (scores for each vocabulary token before
+ SoftMax).
+ logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
+ Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
+ SoftMax).
+ past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+ num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+ Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+ Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+ outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ decoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ decoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+ weighted average in the self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ encoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+ compute the weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, encoder_sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ encoder_sequence_length, encoder_sequence_length)`. Attentions weights of the encoder, after the attention
+ softmax, used to compute the weighted average in the self-attention heads.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ logits_ngram: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ @property
+ def decoder_cross_attentions(self):
+ warnings.warn(
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
+ FutureWarning,
+ )
+ return self.cross_attentions
+
+
+@dataclass
+class XLMProphetNetSeq2SeqModelOutput(ModelOutput):
+ """
+ Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
+ decoding.
+
+ Args:
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
+ Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
+
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+ hidden_size)` is output.
+ last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*):
+ Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
+ past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+ num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+ used (see `past_key_values` input) to speed up sequential decoding.
+ decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+ Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+ decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+ Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+ outputs.
+ decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ decoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ decoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+ weighted average in the self-attention heads.
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ encoder_sequence_length, decoder_sequence_length)`.
+
+ Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+ compute the weighted average in the cross-attention heads.
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
+ encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+ shape `(batch_size, encoder_sequence_length, hidden_size)`.
+
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+ encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+ encoder_sequence_length, encoder_sequence_length)`.
+
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+ self-attention heads.
+ """
+
+ last_hidden_state: torch.FloatTensor
+ last_hidden_state_ngram: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+ encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ @property
+ def decoder_cross_attentions(self):
+ warnings.warn(
+ "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`"
+ " instead.",
+ FutureWarning,
+ )
+ return self.cross_attentions
+
+
+@dataclass
+class XLMProphetNetDecoderModelOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`):
+            Sequence of main stream hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
+            Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model.
+        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+        hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+            outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+    """
+
+    # Fields mirror the docstring above; optional tuples stay None unless the
+    # corresponding `output_*`/`use_cache` flag is set.
+    last_hidden_state: torch.FloatTensor
+    last_hidden_state_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class XLMProphetNetDecoderLMOutput(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`):
+            Prediction scores of the main stream language modeling head (scores for each vocabulary token before
+            SoftMax).
+        logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`):
+            Prediction scores of the predict stream language modeling head (scores for each vocabulary token before
+            SoftMax).
+        past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
+            num_attn_heads, decoder_sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs.
+        hidden_states_ngram (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`.
+
+            Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding
+            outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            decoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the
+            weighted average in the self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads,
+            encoder_sequence_length, decoder_sequence_length)`.
+
+            Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to
+            compute the weighted average in the cross-attention heads.
+    """
+
+    # Fields mirror the docstring above; optional tuples stay None unless the
+    # corresponding `output_*`/`use_cache` flag is set.
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    logits_ngram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class XLMProphetNetPreTrainedModel(PreTrainedModel):
+    """Base class wiring XLMProphetNet models into the `PreTrainedModel`
+    machinery: config binding, weight initialization, gradient-checkpointing
+    support, and decoder-input preparation via `_shift_right`."""
+
+    config_class = XLMProphetNetConfig
+    base_model_prefix = "prophetnet"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        # Draw Linear/Embedding weights from N(0, config.init_std); zero out
+        # biases and the padding-token embedding row.
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _shift_right(self, input_ids):
+        """Shift `input_ids` one position to the right for teacher forcing:
+        prepend `decoder_start_token_id` and replace any -100 label markers
+        with `pad_token_id`. Returns a tensor of the same shape."""
+        decoder_start_token_id = self.config.decoder_start_token_id
+        pad_token_id = self.config.pad_token_id
+
+        assert decoder_start_token_id is not None, (
+            "self.model.config.decoder_start_token_id has to be defined. In XLMProphetNet it is usually set to the"
+            " pad_token_id. See XLMProphetNet docs for more information"
+        )
+
+        # shift inputs to the right
+        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+        shifted_input_ids[..., 0] = decoder_start_token_id
+
+        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        # replace possible -100 values in labels by `pad_token_id`
+        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
+
+        return shifted_input_ids
+
+
+class XLMProphetNetPositionalEmbeddings(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
+    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
+    the forward function.
+    """
+
+    def __init__(self, config: XLMProphetNetConfig) -> None:
+        # `max_length` bounds the position ids produced in `forward` (clamp below).
+        self.max_length = config.max_position_embeddings
+        super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id)
+
+    def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None):
+        """Return `(embeddings, position_ids)`.
+
+        If `position_ids` is not given it is derived either from the cache
+        length (single-step decoding) or from `attention_mask` via a cumulative
+        sum, offset by `padding_idx` so padded positions map to the padding row.
+        """
+        assert (position_ids is None) or (self.padding_idx is None), (
+            "If position_ids is pre-computed then padding_idx should not be set."
+        )
+
+        if position_ids is None:
+            if past_key_values is not None:
+                # position_ids is the same for every token when decoding a single step
+                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+                prev_num_input_ids = past_key_values[0][0].shape[2]
+                num_input_ids = inputs_shape[1] + prev_num_input_ids
+                position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * (
+                    int(self.padding_idx + num_input_ids)
+                )
+            else:
+                if attention_mask is None:
+                    attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device)
+
+                # retrieve position_ids from input_ids / attention_mask
+                position_ids = (
+                    torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
+                ).long() + self.padding_idx
+
+                # make sure position_ids are not bigger then max_length
+                position_ids = position_ids.clamp(0, self.max_length - 1)
+
+        return super().forward(position_ids), position_ids
+
+    def _forward(self, position_ids):
+        # Plain embedding lookup for caller-supplied position ids.
+        return super().forward(position_ids)
+
+
+class XLMProphetNetAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: XLMProphetNetConfig,
+        num_attn_heads: int,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+
+        self.attention_dropout = config.attention_dropout
+        self.dropout = config.dropout
+        self.num_attn_heads = num_attn_heads
+        self.head_dim = hidden_size // num_attn_heads
+
+        assert self.head_dim * num_attn_heads == hidden_size, (
+            "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and"
+            " `config.num_decoder_attention_heads`"
+        )
+
+        self.key_proj = nn.Linear(hidden_size, hidden_size)
+        self.value_proj = nn.Linear(hidden_size, hidden_size)
+        self.query_proj = nn.Linear(hidden_size, hidden_size)
+
+        self.out_proj = nn.Linear(hidden_size, hidden_size)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        # (bsz, seq_len, hidden) -> (bsz, num_heads, seq_len, head_dim)
+        return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states,
+        key_value_states: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+        layer_head_mask: Optional[Tensor] = None,
+        past_key_value: Optional[Tuple[Tensor]] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tuple[Tensor]]]:
+        """Compute (cross-)attention over `hidden_states`.
+
+        When `key_value_states` is given the layer acts as cross-attention
+        (keys/values from the encoder); otherwise it is self-attention.
+        Returns `(attn_output, attn_weights_or_None, past_key_value)` where the
+        attention weights are only populated if `output_attentions` is True.
+        """
+        batch_size, tgt_len, hidden_size = hidden_states.size()
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        assert list(hidden_states.size()) == [
+            batch_size,
+            tgt_len,
+            hidden_size,
+        ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}"
+
+        # previous time steps are cached - no need to recompute key and value if they are static
+        query_states = self.query_proj(hidden_states) / (self.head_dim**0.5)
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape(self.key_proj(key_value_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(key_value_states), -1, batch_size)
+        else:
+            # self_attention
+            key_states = self._shape(self.key_proj(hidden_states), -1, batch_size)
+            value_states = self._shape(self.value_proj(hidden_states), -1, batch_size)
+
+        if is_cross_attention:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        # project states into the correct shape
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
+        query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+        src_len = key_states.size(2)
+        attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3))
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len)
+        if attn_weights.size() != expected_shape:
+            raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}")
+
+        # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
+        if attention_mask is not None and attention_mask.dim() == 0:
+            attention_mask = None
+
+        expected_shape = (batch_size, self.num_attn_heads, 1, src_len)
+        if attention_mask is not None and attention_mask.size() != expected_shape:
+            raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}")
+        if attention_mask is not None:  # don't attend to padding symbols
+            attn_weights = attn_weights + attention_mask
+        # Keep the pre-softmax weights for the caller only when requested.
+        if output_attentions:
+            attn_weights_reshaped = attn_weights
+        else:
+            attn_weights_reshaped = None
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
+                batch_size, self.num_attn_heads, tgt_len, src_len
+            )
+
+            # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model
+            # NOTE(review): if `output_attentions` is False, `attn_weights_reshaped`
+            # is None here and this multiply would raise — confirm callers always
+            # request attentions when passing a head mask.
+            attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped
+
+        attn_probs = nn.functional.dropout(
+            attn_weights,
+            p=self.attention_dropout,
+            training=self.training,
+        )
+        attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states)
+        expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim)
+        if attn_output.size() != expected_shape:
+            raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}")
+
+        # merge heads back: (bsz, heads, tgt, head_dim) -> (bsz, tgt, hidden)
+        attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size)
+        attn_output = self.out_proj(attn_output)
+
+        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+class XLMProphetNetFeedForward(nn.Module):
+    """
+    This is the residual two feed-forward layer block based on the original Transformer implementation.
+    """
+
+    def __init__(self, config: XLMProphetNetConfig, ffn_dim: int):
+        super().__init__()
+        # hidden_size -> ffn_dim -> hidden_size with an activation in between.
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.intermediate = nn.Linear(config.hidden_size, ffn_dim)
+        self.output = nn.Linear(ffn_dim, config.hidden_size)
+        self.activation_dropout = config.activation_dropout
+        self.dropout = config.dropout
+
+    def forward(self, hidden_states):
+        """Apply the two-layer MLP with activation and dropout; the residual
+        connection is handled by the caller."""
+        hidden_states = self.intermediate(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.output(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        return hidden_states
+
+
+class XLMProphetNetNgramSelfAttention(nn.Module):
+    """Decoder self-attention over the main stream and the `ngram` predict
+    streams, with bucketed relative position embeddings added to the attention
+    logits. Input hidden states are the concatenation of all `1 + ngram`
+    streams along the sequence dimension."""
+
+    def __init__(self, config: XLMProphetNetConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.num_buckets = config.num_buckets
+        self.relative_max_distance = config.relative_max_distance
+        self.num_attn_heads = config.num_decoder_attention_heads
+        self.dropout = config.dropout
+        self.attention_dropout = config.attention_dropout
+        self.head_dim = config.hidden_size // self.num_attn_heads
+        self.ngram = config.ngram
+
+        assert self.head_dim * self.num_attn_heads == config.hidden_size, (
+            "config.hidden_size must be divisible by num_attn_heads"
+        )
+        # key, value, query projection
+        self.key_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.value_proj = nn.Linear(config.hidden_size, config.hidden_size)
+        self.query_proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+        # out projection
+        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
+
+        # rel position embeddings
+        self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads)
+
+        # for onnx runtime
+        self.onnx_trace = False
+
+    def _shape(self, tensor, seq_len, batch_size):
+        # (bsz, seq_len, hidden) -> (bsz, num_heads, seq_len, head_dim)
+        return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def prepare_for_onnx_export_(self):
+        # Switches `softmax` to its ONNX-compatible path.
+        self.onnx_trace = True
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_value: Optional[Tuple[Tensor]] = None,
+        attention_mask=None,
+        layer_head_mask=None,
+        extended_predict_attention_mask=None,
+        main_relative_position_buckets=None,
+        predict_relative_position_buckets=None,
+        position_ids=None,
+    ):
+        """Run main-stream and predict-stream attention.
+
+        Returns `(attn_output, main_attn_probs, predict_attn_probs,
+        past_key_value)`; only the main stream's keys/values are cached.
+        """
+        batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
+        assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
+            f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
+            f" {hidden_states.shape}"
+        )
+
+        # project
+        query_states = self.query_proj(hidden_states)
+        key_states = self.key_proj(hidden_states)
+        value_states = self.value_proj(hidden_states)
+
+        # normalize
+        query_states = query_states / (self.head_dim**0.5)
+
+        # reshape
+        query_states = self._shape(query_states, ngram_sequence_length, batch_size)
+        key_states = self._shape(key_states, -1, batch_size)
+        value_states = self._shape(value_states, -1, batch_size)
+        proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim)
+
+        query_states = query_states.view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        # chunk into main stream and predict stream
+        hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1)
+        query_states_list = query_states.chunk(1 + self.ngram, dim=2)
+        key_states_list = key_states.chunk(1 + self.ngram, dim=2)
+        value_states_list = value_states.chunk(1 + self.ngram, dim=2)
+
+        main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:]
+        main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:]
+        main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:]
+        main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:]
+
+        # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim)
+        if past_key_value is not None:
+            prev_main_key_states = past_key_value[0]
+            main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2)
+            prev_main_value_states = past_key_value[1]
+            main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2)
+
+        # Update cache
+        past_key_value = (main_key_states, main_value_states)
+
+        # get seq_length of main stream only
+        sequence_length = ngram_sequence_length // (1 + self.ngram)
+
+        # MAIN-STREAM
+        # main attn weights
+        # [batch_size, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, number_heads, head_dimesion, sequence_length]
+        # -> [batch_size, number_heads, sequence_length, sequence_length]
+        main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3))
+
+        # retrieve relative position embeddings for each layer -> see paper for more details
+        main_relative_pos_embeddings = self.get_main_relative_pos_embeddings(
+            main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets
+        )
+
+        main_attn_weights = main_attn_weights + main_relative_pos_embeddings
+
+        if attention_mask is not None:
+            main_attn_weights = main_attn_weights + attention_mask
+
+        main_attn_probs = softmax(
+            main_attn_weights,
+            dim=-1,
+            onnx_trace=self.onnx_trace,
+        ).type_as(main_attn_weights)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view(
+                batch_size, self.num_attn_heads, -1, sequence_length
+            )
+
+        main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training)
+        # project to attn_output
+        # [batch_size, number_heads, sequence_length, sequence_length]
+        # x [batch_size, number_heads, sequence_length, head_dimesion]
+        # -> [batch_size, number_heads, sequence_length, head_dimesion]
+        main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states)
+        # reshape so that num_heads dim is merged into last `head_dim` axis
+        main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size)
+        main_attn_output = self.out_proj(main_attn_output)
+
+        # PREDICT-STREAM
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_query_states = torch.stack(predict_query_states_list, 1).view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim
+        )
+
+        # Each predict stream attends to the full main stream plus itself.
+        # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1)
+
+        # [batch_size, sequence_length, ngram, hidden_size]
+        predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2)
+
+        # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion]
+        predict_value_states = torch.cat(
+            [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2
+        )
+
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states))
+
+        # retrieve relative position embeddings for each layer -> see paper for more details
+        # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings]
+        predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings(
+            predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets
+        )
+
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings
+
+        if extended_predict_attention_mask is not None:
+            # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+            extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4)
+            extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype)
+            predict_attn_weights = predict_attn_weights + extended_predict_attention_mask
+
+        predict_attn_probs = softmax(
+            predict_attn_weights,
+            dim=-1,
+            onnx_trace=self.onnx_trace,
+        ).type_as(predict_attn_weights)
+
+        if layer_head_mask is not None:
+            assert layer_head_mask.size() == (self.num_attn_heads,), (
+                f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is"
+                f" {layer_head_mask.size()}"
+            )
+            predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs
+
+        predict_attn_probs = nn.functional.dropout(
+            predict_attn_probs, p=self.attention_dropout, training=self.training
+        )
+        # project to attention output
+        # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length]
+        # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion]
+        # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion]
+        predict_attn_output = torch.einsum(
+            "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2))
+        )
+
+        # reshape so that num_heads dim is merged into last `head_dim` axis
+        # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size]
+        predict_attn_output = predict_attn_output.transpose(2, 3)
+        predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size)
+        predict_attn_output = self.out_proj(predict_attn_output)
+
+        # concat to single attn output
+        # [batch_size, (1+ngram)*sequence_length, hidden_size]
+        attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size)
+        # reshape into better form for `config.output_attentions`
+        main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1)
+
+        attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training)
+
+        return attn_output, main_attn_probs, predict_attn_probs, past_key_value
+
+    def get_main_relative_pos_embeddings(
+        self, hidden_states, attn_weights, position_ids, main_relative_position_buckets
+    ):
+        """Gather per-head relative position embeddings for the main stream,
+        returned with the same shape as `attn_weights`. Buckets are computed
+        from `position_ids` when not supplied."""
+        # input hidden_states [batch_size, sequence_length, hidden_size]
+        # input attn_weights [batch_size, num_heads, sequence_length, sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape
+        attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len)
+        if main_relative_position_buckets is None:
+            batch_size, sequence_length = hidden_states.shape[:2]
+            relative_positions = (
+                torch.arange(1, attn_weights.shape[-1] + 1)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .repeat(batch_size, sequence_length, 1)
+                .to(position_ids.device)
+            )
+            # [batch_size, sequence_length, sequence_length+1]
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
+            main_relative_position_buckets = compute_relative_buckets(
+                self.num_buckets, self.relative_max_distance, relative_positions, False
+            )
+
+        # [batch_size, sequence_length, num_buckets * num_heads]
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+        rel_pos_embeddings = rel_pos_embeddings.view(
+            rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads)
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2)
+        # [batch_size, num_heads, sequence_length, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,))
+
+        main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1)
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        main_relative_position_buckets = main_relative_position_buckets.view(
+            -1, main_relative_position_buckets.shape[-1]
+        )
+        main_relative_position_buckets = main_relative_position_buckets.long()
+        # [batch_size * num_heads * sequence_length, sequence_length]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1))
+
+        # Select the embedding for each (query, key) pair's bucket.
+        main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets)
+        main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1)
+        return main_relative_pos_embeddings
+
+    def get_predict_relative_pos_embeddings(
+        self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets
+    ):
+        """Gather per-head relative position embeddings for the predict
+        streams, returned with the same shape as `attn_weights`. Buckets are
+        computed from `position_ids` when not supplied."""
+        # input hidden_states [batch_size, sequence_length, ngram, hidden_size]
+        # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length]
+        # input position_ids [batch_size, sequence_length] or [1,1]
+        # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None
+        batch_size, sequence_length = hidden_states.shape[0:2]
+
+        if predict_relative_position_buckets is None:
+            key_sequence_length = attn_weights.shape[-1]
+            assert position_ids[0][0] == key_sequence_length - 1, (
+                "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)"
+            )
+            relative_positions = (
+                torch.arange(0, key_sequence_length)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .repeat(batch_size, sequence_length, 1)
+                .to(position_ids.device)
+            )
+
+            relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1)
+            predict_relative_position_buckets = compute_relative_buckets(
+                self.num_buckets, self.relative_max_distance, relative_positions, False
+            )
+
+        # [batch_size, ngram, sequence_length, hidden_size]
+        hidden_states = hidden_states.transpose(1, 2)
+        rel_pos_embeddings = self.relative_pos_embeddings(hidden_states)
+
+        # [batch_size, ngram, sequence_length, num_buckets, num_heads]
+        rel_pos_embeddings = rel_pos_embeddings.view(
+            hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads)
+        )
+        rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3)
+        # [batch_size * ngram * sequence_length * num_heads, num_buckets]
+        rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets)
+        # [ngram, batch_size, num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0)
+        predict_relative_position_buckets = predict_relative_position_buckets.repeat(
+            self.ngram, 1, self.num_attn_heads, 1
+        )
+        # [ngram * batch_size * num_heads * sequence_length, -1]
+        predict_relative_position_buckets = predict_relative_position_buckets.view(
+            -1, predict_relative_position_buckets.size(-1)
+        ).long()
+
+        # Select the embedding for each (query, key) pair's bucket.
+        predict_relative_pos_embeddings = torch.gather(
+            rel_pos_embeddings, dim=1, index=predict_relative_position_buckets
+        )
+
+        # [batch_size, ngram, num_heads, sequence_length, -1]
+        predict_relative_pos_embeddings = predict_relative_pos_embeddings.view(
+            batch_size, self.ngram, self.num_attn_heads, sequence_length, -1
+        )
+
+        return predict_relative_pos_embeddings
+
+
+class XLMProphetNetEncoderLayer(nn.Module):
+    """
+    Encoder block for XLMProphetnet
+    """
+
+    def __init__(self, config: XLMProphetNetConfig):
+        super().__init__()
+        # 1st residual block
+        self.self_attn = XLMProphetNetAttention(config, config.num_encoder_attention_heads)
+        self.self_attn_layer_norm = LayerNorm(config.hidden_size)
+
+        # 2nd residual block
+        self.feed_forward = XLMProphetNetFeedForward(config, config.encoder_ffn_dim)
+        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        layer_head_mask,
+        output_attentions: bool = False,
+    ):
+        """Run self-attention and feed-forward residual blocks (post-LayerNorm).
+
+        Returns `(hidden_states,)`, extended with the attention weights when
+        `output_attentions` is True.
+        """
+        # 1st residual block
+        attention_output, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = self.self_attn_layer_norm(attention_output + hidden_states)
+
+        # 2nd residual block
+        feed_forward_output = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
class XLMProphetNetDecoderLayer(nn.Module):
    """
    Decoder block for XLMProphetnet.

    Up to three residual sub-blocks, each followed by LayerNorm: n-gram
    self-attention, optional cross-attention over the encoder output, and a
    position-wise feed-forward network.
    """

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__()
        # 1st residual block: n-gram (main + predict stream) self-attention
        self.self_attn = XLMProphetNetNgramSelfAttention(config)
        self.self_attn_layer_norm = LayerNorm(config.hidden_size)

        # 2nd residual block: cross-attention, only built when configured.
        # NOTE(review): forward() accesses `self.cross_attn` whenever
        # `encoder_hidden_states` is passed — callers must only do so when
        # `config.add_cross_attention` is True.
        if config.add_cross_attention:
            self.cross_attn = XLMProphetNetAttention(config, config.num_decoder_attention_heads)
            self.cross_attn_layer_norm = LayerNorm(config.hidden_size)

        # 3rd residual block: feed-forward network
        self.feed_forward = XLMProphetNetFeedForward(config, config.decoder_ffn_dim)
        self.feed_forward_layer_norm = LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attn_mask=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        extended_predict_attention_mask=None,
        main_relative_position_buckets=None,
        predict_relative_position_buckets=None,
        position_ids=None,
        past_key_value=None,
        use_cache: bool = True,
        output_attentions: bool = False,
    ):
        """
        Run the decoder layer.

        Returns a tuple `(hidden_states, [self_attn_weights,
        self_attn_weights_ngram, cross_attn_weights,] [present_key_value])`:
        the three attention-weight entries are appended only when
        `output_attentions` is True, and the cache entry only when `use_cache`
        is True. Callers index into this tuple by position, so the ordering is
        part of the contract.
        """
        # 1st residual block
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            extended_predict_attention_mask=extended_predict_attention_mask,
            main_relative_position_buckets=main_relative_position_buckets,
            predict_relative_position_buckets=predict_relative_position_buckets,
            position_ids=position_ids,
        )
        hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output)

        # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
        cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            # 2nd residual block
            attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attn_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states)

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # 3rd residual block
        feed_forward_output = self.feed_forward(hidden_states)
        hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
+
+
@add_start_docstrings(
    "The standalone encoder part of the XLMProphetNetModel.",
    XLM_PROPHETNET_START_DOCSTRING,
)
class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
    r"""
    word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
        The word embedding parameters. This can be used to initialize [`XLMProphetNetEncoder`] with pre-defined word
        embeddings instead of randomly initialized word embeddings.
    """

    def __init__(self, config: XLMProphetNetConfig, word_embeddings: nn.Embedding = None):
        super().__init__(config)

        # Optionally share the word embedding matrix with the caller
        # (e.g. XLMProphetNetModel ties encoder/decoder embeddings this way).
        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = XLMProphetNetPositionalEmbeddings(config)
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.layers = nn.ModuleList([XLMProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Standard PreTrainedModel hook: expose the input embedding matrix.
        return self.word_embeddings

    def set_input_embeddings(self, value):
        # Standard PreTrainedModel hook: replace the input embedding matrix.
        self.word_embeddings = value

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMProphetNetEncoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = XLMProphetNetEncoder.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""

        # Fall back to config defaults for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds has to be passed.")
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("Make sure to only pass input_ids or inputs_embeds.")
        elif input_ids is not None and inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # prepare attention mask: turn the 0/1 padding mask into additive
        # bias (min-float at masked positions), broadcast over heads.
        if attention_mask is not None:
            extended_attention_mask = (
                1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
        else:
            extended_attention_mask = None

        position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device)

        # Embed: word + position, LayerNorm, dropout.
        hidden_states = inputs_embeds + position_embeddings
        hidden_states = self.embeddings_layer_norm(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training)

        encoder_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (len(self.layers)), (
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
            )
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Record the *input* to each layer (final layer output added below).
                encoder_hidden_states = encoder_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward to save memory.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    extended_attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_hidden_states = encoder_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions
        )
+
+
@add_start_docstrings(
    "The standalone decoder part of the XLMProphetNetModel.",
    XLM_PROPHETNET_START_DOCSTRING,
)
class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
    r"""
    word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*):
        The word embedding parameters. This can be used to initialize [`XLMProphetNetDecoder`] with pre-defined word
        embeddings instead of randomly initialized word embeddings.
    """

    def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None):
        super().__init__(config)

        # n-gram prediction hyper-parameters: `ngram` extra "predict" streams
        # are run next to the main stream (ProphetNet's future n-gram prediction).
        self.ngram = config.ngram
        self.num_buckets = config.num_buckets
        self.relative_max_distance = config.relative_max_distance
        self.dropout = config.dropout
        self.max_target_positions = config.max_position_embeddings

        # Optionally share the word embedding matrix with the caller
        # (e.g. XLMProphetNetModel ties encoder/decoder embeddings this way).
        self.word_embeddings = (
            word_embeddings
            if word_embeddings is not None
            else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        )
        self.position_embeddings = XLMProphetNetPositionalEmbeddings(config)

        # One learned embedding per predict stream, added to the stream's inputs.
        self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None)
        self.layers = nn.ModuleList([XLMProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)])
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Standard PreTrainedModel hook: expose the input embedding matrix.
        return self.word_embeddings

    def set_input_embeddings(self, value):
        # Standard PreTrainedModel hook: replace the input embedding matrix.
        self.word_embeddings = value

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, XLMProphetNetDecoderModelOutput]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMProphetNetDecoder
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = XLMProphetNetDecoder.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone", add_cross_attention=False)
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        # Fall back to config defaults for any unspecified flags.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.")
        elif input_ids is not None and inputs_embeds is not None:
            raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.")
        elif input_ids is not None and inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        batch_size, sequence_length = inputs_embeds.shape[:2]

        # Position embeddings for the main stream; `past_key_values` shifts the
        # position offset during incremental decoding.
        main_stream_pos_embed, position_ids = self.position_embeddings(
            (batch_size, sequence_length),
            device=inputs_embeds.device,
            past_key_values=past_key_values,
        )

        if past_key_values is not None:
            # Cached decoding: relative buckets are not needed (single step).
            main_relative_position_buckets, predict_relative_position_buckets = None, None
        else:
            (
                main_relative_position_buckets,
                predict_relative_position_buckets,
            ) = self.compute_buffered_relative_buckets(position_ids)
        # Predict streams attend one position ahead of the main stream.
        predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1)

        # add position embeddings
        hidden_states = inputs_embeds + main_stream_pos_embed

        ngram_embeddings = self.ngram_embeddings.weight

        # prepare attention mask
        if past_key_values is not None:
            assert hidden_states.size(1) == 1, (
                "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1"
            )

            # Single-step decoding: each predict stream is its stream embedding
            # plus the predicting-stream position embedding, repeated per batch.
            ngram_hidden_states = [
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1)
                for ngram in range(self.ngram)
            ]
            # No masks needed: a single query position attends to the full cache.
            extended_attention_mask = None
            extended_predict_attention_mask = None
        else:
            ngram_hidden_states = [
                (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram)
            ]
            # Causal (+padding) masks for the main and predict streams.
            extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask)
            extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask)

        # prepare encoder attention mask: additive bias broadcast over heads.
        if encoder_attention_mask is not None:
            extended_encoder_attention_mask = (
                1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1)
            ) * torch.finfo(self.dtype).min
            extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
        else:
            extended_encoder_attention_mask = None

        # Concatenate main stream and predict streams along the sequence axis;
        # layers process all streams jointly, the output is split again below.
        hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1)

        if self.embeddings_layer_norm:
            hidden_states = self.embeddings_layer_norm(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # init attentions, hidden_states and cache with empty tuples
        all_main_stream_hidden_states = () if output_hidden_states else None
        all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None

        all_main_stream_attns = () if output_attentions else None
        all_ngram_stream_attns = () if output_attentions else None
        all_cross_attns = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # Checkpointing recomputes activations, so a cache cannot be kept.
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        present_key_values = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        # NOTE(review): the f-string below reports `head_mask.size()` even when
        # `cross_attn_head_mask` failed the check — message likely should use
        # `attn_mask.size()`; the check itself is correct.
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                assert attn_mask.size()[0] == (len(self.layers)), (
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # grad cannot be kept because tensor is sliced
                all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
                if self.config.ngram > 0:
                    all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    extended_attention_mask,
                    encoder_hidden_states,
                    extended_encoder_attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                    (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                    extended_predict_attention_mask,
                    main_relative_position_buckets,
                    predict_relative_position_buckets,
                    position_ids,
                    None,
                    use_cache,
                    output_attentions,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attn_mask=extended_encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    extended_predict_attention_mask=extended_predict_attention_mask,
                    main_relative_position_buckets=main_relative_position_buckets,
                    predict_relative_position_buckets=predict_relative_position_buckets,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            # The cache tuple sits after the three attention entries when
            # `output_attentions` is on, otherwise right after hidden_states.
            if use_cache:
                present_key_values += (layer_outputs[4 if output_attentions else 1],)

            if output_attentions:
                all_main_stream_attns += (layer_outputs[1],)
                all_ngram_stream_attns += (layer_outputs[2],)

                if self.config.add_cross_attention:
                    all_cross_attns += (layer_outputs[3],)

        if output_hidden_states:
            all_main_stream_hidden_states += (hidden_states[:, :sequence_length],)
            if self.config.ngram > 0:
                all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],)

        # split last_hidden_state for return
        last_hidden_state = hidden_states[:, :sequence_length]
        last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    last_hidden_state,
                    last_hidden_state_ngram,
                    present_key_values,
                    all_main_stream_hidden_states,
                    all_ngram_stream_hidden_states,
                    all_main_stream_attns,
                    all_ngram_stream_attns,
                    all_cross_attns,
                ]
                if v is not None
            )
        return XLMProphetNetDecoderModelOutput(
            last_hidden_state=last_hidden_state,
            last_hidden_state_ngram=last_hidden_state_ngram,
            past_key_values=present_key_values,
            hidden_states=all_main_stream_hidden_states,
            hidden_states_ngram=all_ngram_stream_hidden_states,
            attentions=all_main_stream_attns,
            ngram_attentions=all_ngram_stream_attns,
            cross_attentions=all_cross_attns,
        )

    def compute_buffered_relative_buckets(self, position_ids):
        """
        Compute relative-position buckets for the main and predict streams,
        precomputed over `max_target_positions` and then cropped to the actual
        sequence length and repeated per batch element.
        """
        batch_size, sequence_length = position_ids.shape

        # Buckets are computed over all possible target positions once.
        position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1)
        main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets(
            self.num_buckets, self.relative_max_distance, position_ids
        )

        # buffer relative buckets
        main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1)
        # Predict-stream keys cover the main stream and the shifted predict
        # stream, hence two cropped slices concatenated along the key axis.
        predict_relative_buckets = torch.cat(
            [
                predict_relative_buckets[:, :sequence_length, :sequence_length],
                predict_relative_buckets[
                    :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length
                ],
            ],
            2,
        ).repeat(batch_size, 1, 1)

        return main_relative_buckets, predict_relative_buckets

    def prepare_attention_mask(self, hidden_states, attention_mask):
        """
        Build the additive causal (+ optional padding) mask for the main
        stream, shaped `(batch, num_heads, seq_len, seq_len)`.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask: min-float above the diagonal forbids attending ahead.
        causal_mask = torch.full(
            (seq_length, seq_length),
            torch.finfo(hidden_states.dtype).min,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        causal_mask = torch.triu(causal_mask, 1)

        extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_causal_mask + extended_attention_mask
        else:
            extended_attention_mask = extended_causal_mask
        return extended_attention_mask.to(hidden_states.dtype)

    def prepare_predict_attention_mask(self, hidden_states, attention_mask):
        """
        Build the additive mask for the predict streams, shaped
        `(batch, num_heads, ngram, seq_len, 2 * seq_len)` — keys cover the
        main stream followed by the shifted predict stream.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # get causal mask (precomputed over max_target_positions, then cropped)
        predict_causal_mask = ngram_attention_bias(
            self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype
        )
        predict_causal_mask = torch.cat(
            [
                predict_causal_mask[:, :seq_length, :seq_length],
                predict_causal_mask[
                    :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length
                ],
            ],
            dim=-1,
        )
        extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand(
            (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape
        )

        # add usual attention mask
        if attention_mask is not None:
            extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min
            extended_attention_mask = extended_attention_mask.expand(
                (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length)
            )
            # predicted stream attention_mask should always be 0
            extended_attention_mask = torch.cat(
                [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1
            )
            extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask
        else:
            extended_predict_attention_mask = extended_predict_causal_mask
        return extended_predict_attention_mask.to(hidden_states.dtype)
+
+
@add_start_docstrings(
    "The bare XLMProphetNet Model outputting raw hidden-states without any specific head on top.",
    XLM_PROPHETNET_START_DOCSTRING,
)
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
    # Keys of weights tied to `self.word_embeddings` (excluded from missing-key warnings).
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"]

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)
        # One shared embedding matrix, passed to both encoder and decoder.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # The sub-configs are deep-copied so the flag edits below don't leak
        # into `self.config` or the sibling sub-model.
        encoder_config = copy.deepcopy(config)
        encoder_config.is_encoder_decoder = False
        encoder_config.use_cache = False
        self.encoder = XLMProphetNetEncoder(encoder_config, self.word_embeddings)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        self.decoder = XLMProphetNetDecoder(decoder_config, self.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Standard PreTrainedModel hook: expose the shared embedding matrix.
        return self.word_embeddings

    def set_input_embeddings(self, value):
        # Keep encoder and decoder pointing at the same (new) embedding matrix.
        self.word_embeddings = value
        self.encoder.word_embeddings = self.word_embeddings
        self.decoder.word_embeddings = self.word_embeddings

    def _tie_weights(self):
        # Tie (or clone, depending on config) encoder/decoder embeddings to the shared matrix.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings)
            self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, XLMProphetNetSeq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMProphetNetModel

        >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = XLMProphetNetModel.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> last_hidden_states = outputs.last_hidden_state  # main stream hidden states
        >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram  # predict hidden states
        ```"""
        # Fall back to config defaults for any unspecified flags.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the encoder only when its outputs were not precomputed
        # (e.g. during incremental generation they are reused across steps).
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs
        return XLMProphetNetSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
            decoder_attentions=decoder_outputs.attentions,
            decoder_ngram_attentions=decoder_outputs.ngram_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    "The XLMProphetNet Model with a language modeling head. Can be used for sequence generation tasks.",
    XLM_PROPHETNET_START_DOCSTRING,
)
class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
    # Keys of weights tied to the shared embedding matrix (excluded from missing-key warnings).
    _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"]

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)
        self.prophetnet = XLMProphetNetModel(config)
        self.padding_idx = config.pad_token_id
        # When set, only the main (next-token) stream contributes to the loss.
        self.disable_ngram_loss = config.disable_ngram_loss

        # LM head projecting hidden states to vocabulary logits; weight may be
        # tied to the word embeddings via _tie_weights().
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Standard PreTrainedModel hook: expose the LM head.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        # Standard PreTrainedModel hook: replace the LM head.
        self.lm_head = new_embeddings

    def _tie_weights(self):
        # Tie (or clone, depending on config) the LM head to the shared word embeddings.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head)

    def get_input_embeddings(self):
        # Input embeddings live on the wrapped base model.
        return self.prophetnet.word_embeddings
+
    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, XLMProphetNetSeq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMProphetNetForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = XLMProphetNetForConditionalGeneration.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")

        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        >>> logits_next_token = outputs.logits  # logits to predict next token as usual
        >>> logits_ngram_next_tokens = outputs.logits_ngram  # logits to predict 2nd, 3rd, ... next tokens
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        outputs = self.prophetnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        batch_size, sequence_length = (
            decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2]
        )

        # outputs[1] is the concatenated predict-stream hidden states; reshape
        # to (batch, ngram, seq_len, hidden) before projecting to vocab logits.
        predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
        predict_logits = self.lm_head(predicting_streams)

        # Stream 0 predicts the next token (standard LM logits); the remaining
        # streams predict the 2nd, 3rd, ... future tokens.
        logits = predict_logits[:, 0]
        logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

        # To use .view in loss computation, make sure that logits is contiguous.
        if not logits.is_contiguous():
            logits = logits.contiguous()

        loss = None
        if labels is not None:
            # Loss covers all predict streams (subject to disable_ngram_loss).
            loss = self._compute_loss(predict_logits, labels)

        if not return_dict:
            all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
            return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
        else:
            return XLMProphetNetSeq2SeqLMOutput(
                loss=loss,
                logits=logits,
                logits_ngram=logits_ngram,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                decoder_ngram_attentions=outputs.decoder_ngram_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )
+
+ def _compute_loss(self, logits, labels, ignore_index=-100):
+ expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)
+
+ for i in range(self.config.ngram):
+ if i > 0 and self.disable_ngram_loss:
+ break
+ expend_targets[i, :, :] = labels
+
+ logits = logits.transpose(0, 1).contiguous()
+ lprobs = nn.functional.log_softmax(
+ logits.view(-1, logits.size(-1)),
+ dim=-1,
+ dtype=torch.float32,
+ )
+
+ loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")
+
+ if self.config.eps > 0.0:
+ smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
+ non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
+ smooth_loss = smooth_loss[non_masked_tokens]
+ smooth_loss = smooth_loss.mean()
+
+ eps_i = self.config.eps / lprobs.size(-1)
+ loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss
+
+ return loss
+
+ def prepare_inputs_for_generation(
+ self,
+ decoder_input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ head_mask=None,
+ decoder_head_mask=None,
+ cross_attn_head_mask=None,
+ use_cache=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation."
+
+ if past_key_values:
+ decoder_input_ids = decoder_input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
+ "encoder_outputs": encoder_outputs,
+ "past_key_values": past_key_values,
+ "decoder_input_ids": decoder_input_ids,
+ "attention_mask": attention_mask,
+ "head_mask": head_mask,
+ "decoder_head_mask": decoder_head_mask,
+ "cross_attn_head_mask": cross_attn_head_mask,
+ "use_cache": use_cache,
+ }
+
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Teacher forcing: decoder inputs are the labels shifted one position to
        # the right (with the decoder start token prepended by `_shift_right`).
        return self._shift_right(labels)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ + layer_past[2:],
+ )
+ return reordered_past
+
    def get_encoder(self):
        # Accessor used by `generate()` to run the encoder once before decoding.
        return self.prophetnet.encoder
+
    def get_decoder(self):
        # Accessor for the underlying ProphetNet decoder stack.
        return self.prophetnet.decoder
+
+
@add_start_docstrings(
    "The standalone decoder part of the XLMProphetNetModel with a lm head on top. The model can be used for causal"
    " language modeling.",
    XLM_PROPHETNET_START_DOCSTRING,
)
class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
    """Standalone XLM-ProphetNet decoder with a language-modeling head on top."""

    # Weights that are tied to each other; `from_pretrained` uses this list when
    # matching and de-duplicating checkpoint keys.
    _tied_weights_keys = [
        "prophetnet.word_embeddings.weight",
        "prophetnet.decoder.word_embeddings.weight",
        "lm_head.weight",
    ]

    def __init__(self, config: XLMProphetNetConfig):
        # set config for CLM: force decoder-only behaviour on a deep copy so the
        # caller's config object is left untouched.
        config = copy.deepcopy(config)
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.prophetnet = XLMProphetNetDecoderWrapper(config)

        self.padding_idx = config.pad_token_id
        self.disable_ngram_loss = config.disable_ngram_loss

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the decoder's token embedding table."""
        return self.prophetnet.decoder.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding table."""
        self.prophetnet.decoder.word_embeddings = value

    def get_output_embeddings(self):
        """Return the LM head (output projection onto the vocabulary)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head (output projection onto the vocabulary)."""
        self.lm_head = new_embeddings

    def _tie_weights(self):
        # Tie (or clone, depending on the backend) the LM head to the decoder's
        # input embeddings when the config requests weight tying.
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head)

    def set_decoder(self, decoder):
        self.prophetnet.decoder = decoder

    def get_decoder(self):
        return self.prophetnet.decoder

    @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=XLMProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, XLMProphetNetDecoderLMOutput]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, XLMProphetNetForCausalLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = XLMProphetNetForCausalLM.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> logits = outputs.logits

        >>> # Model can also be used with EncoderDecoder framework
        >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer
        >>> import torch

        >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")
        >>> tokenizer_dec = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone")
        >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google-bert/bert-large-uncased", "patrickvonplaten/xprophetnet-large-uncased-standalone"
        ... )

        >>> ARTICLE = (
        ...     "the us state department said wednesday it had received no "
        ...     "formal word from bolivia that it was expelling the us ambassador there "
        ...     "but said the charges made against him are `` baseless ."
        ... )
        >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids
        >>> labels = tokenizer_dec(
        ...     "us rejects charges against its ambassador in bolivia", return_tensors="pt"
        ... ).input_ids
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:])

        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
        outputs = self.prophetnet.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2]

        # outputs[1] holds the stacked predicting streams; reshape to
        # (batch, ngram, seq_len, hidden) before projecting to the vocabulary.
        predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1)
        predict_logits = self.lm_head(predicting_streams)

        # Stream 0 predicts the next token; streams 1..ngram-1 predict tokens further ahead.
        logits = predict_logits[:, 0]
        logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None

        loss = None
        if labels is not None:
            loss = self._compute_loss(predict_logits, labels)

        if not return_dict:
            # Tuple output: (loss?, logits, logits_ngram?, *remaining decoder outputs).
            all_logits = tuple(v for v in [logits, logits_ngram] if v is not None)
            return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:]
        else:
            return XLMProphetNetDecoderLMOutput(
                loss=loss,
                logits=logits,
                logits_ngram=logits_ngram,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                hidden_states_ngram=outputs.hidden_states_ngram,
                attentions=outputs.attentions,
                ngram_attentions=outputs.ngram_attentions,
                cross_attentions=outputs.cross_attentions,
            )

    def _compute_loss(self, logits, labels, ignore_index=-100):
        """Cross-entropy over all n-gram predicting streams with optional label
        smoothing (`config.eps`); only the first stream contributes when
        `disable_ngram_loss` is set."""
        # Replicate the labels once per stream; streams whose loss is disabled
        # stay filled with `ignore_index` and are skipped by `nll_loss`.
        expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index)

        for i in range(self.config.ngram):
            if i > 0 and self.disable_ngram_loss:
                break
            expend_targets[i, :, :] = labels

        # (batch, stream, ...) -> (stream, batch, ...) so logits line up with targets.
        logits = logits.transpose(0, 1).contiguous()
        # log-softmax in float32 for numerical stability regardless of model dtype.
        lprobs = nn.functional.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )

        loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean")

        if self.config.eps > 0.0:
            # Label smoothing: mix in the mean negative log-probability over the
            # whole vocabulary, computed on non-ignored positions only.
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
            non_masked_tokens = expend_targets.ne(ignore_index).view(-1)
            smooth_loss = smooth_loss[non_masked_tokens]
            smooth_loss = smooth_loss.mean()

            eps_i = self.config.eps / lprobs.size(-1)
            loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss

        return loss

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=None,
        **kwargs,
    ):
        """Assemble the model inputs for a single causal-LM generation step."""
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        if past_key_values:
            # With a populated cache, only the newest token has to be fed through.
            input_ids = input_ids[:, -1:]
        # first step, decoder_cached_states are empty
        return {
            "input_ids": input_ids,  # full ids on the first step, last token only afterwards
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached state along the batch dimension to follow
        `beam_idx` during beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel):
    """
    Thin wrapper module so that [`XLMProphetNetForCausalLM`] can correctly be loaded from checkpoints saved for
    pretrained XLMProphetNet classes.
    """

    def __init__(self, config: XLMProphetNetConfig):
        super().__init__(config)

        # The embedding table is owned by the wrapper and shared with the decoder.
        embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.word_embeddings = embeddings
        self.decoder = XLMProphetNetDecoder(config, word_embeddings=embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def _tie_weights(self):
        # Keep the wrapper's embedding table tied to the decoder's input embeddings.
        self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings())

    def forward(self, *args, **kwargs):
        # Pure delegation: the wrapper adds nothing on top of the decoder call.
        return self.decoder(*args, **kwargs)
+
+
+__all__ = [
+ "XLMProphetNetDecoder",
+ "XLMProphetNetEncoder",
+ "XLMProphetNetForCausalLM",
+ "XLMProphetNetForConditionalGeneration",
+ "XLMProphetNetModel",
+ "XLMProphetNetPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5da12859f29e9e3cabe354c63d25c584edbd49
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py
@@ -0,0 +1,326 @@
+# coding=utf-8
+# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+from ....tokenization_utils import PreTrainedTokenizer
+from ....utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {"vocab_file": "prophetnet.tokenizer"}
+
+
def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an ordered token -> index mapping."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # Iterate the file directly; only the trailing newline is stripped so
        # tokens containing other whitespace survive intact.
        for index, line in enumerate(reader):
            vocab[line.rstrip("\n")] = index
    return vocab
+
+
class XLMProphetNetTokenizer(PreTrainedTokenizer):
    """
    Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
    [SentencePiece](https://github.com/google/sentencepiece).

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        bos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier
            token.



            When building a sequence using special tokens, this is not the token that is used for the beginning of
            sequence. The token used is the `cls_token`.



        eos_token (`str`, *optional*, defaults to `"[SEP]"`):
            The end of sequence token.



            When building a sequence using special tokens, this is not the token that is used for the end of sequence.
            The token used is the `sep_token`.



        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        bos_token="[SEP]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        unk_token="[UNK]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

        # sentencepiece is an optional dependency; imported lazily so the module
        # can be imported without it.
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(str(vocab_file))
        self.vocab_file = vocab_file

        # Original fairseq vocab and spm vocab must be "aligned":
        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'

        # put special tokens and [unused] tokens into the vocab
        self.fairseq_tokens_to_ids = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "[MASK]": 4}

        # Ten reserved "[unused]" slots occupy ids 5..14 in the embedding vocab.
        for i in range(10):
            tok = f"[unused{i}]"
            self.fairseq_tokens_to_ids[tok] = 5 + i

        # The first "real" token "," has position 15 in the embedding vocab and position 3 in the spm vocab
        self.fairseq_offset = 12
        self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}

        # TODO ArthurZ fairseq_ids_to_tokens should be removed

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            unk_token=unk_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            sp_model_kwargs=self.sp_model_kwargs,
            **kwargs,
        )

    @property
    def can_save_slow_tokenizer(self) -> bool:
        # Saving requires the original sentencepiece model file to still exist on disk.
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def __getstate__(self):
        # The SentencePieceProcessor is not picklable; drop it and reload it in
        # `__setstate__` from `self.vocab_file`.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece"
                " pip install sentencepiece"
            )
            raise

        # for backward compatibility
        if not hasattr(self, "sp_model_kwargs"):
            self.sp_model_kwargs = {}

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        # Mirrors `build_inputs_with_special_tokens`: a single `[SEP]` after each sequence.
        if token_ids_1 is None:
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLMProphetNet
        does not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.

        """

        sep = [self.sep_token_id]

        if token_ids_1 is None:
            return len(token_ids_0 + sep) * [0]
        return len(token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

    @property
    def vocab_size(self):
        """Size of the full vocabulary: sentencepiece pieces plus the fairseq offset."""
        return len(self.sp_model) + self.fairseq_offset

    def get_vocab(self):
        """Return the full token -> id mapping, including added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text: str) -> List[str]:
        # Note: returns a list of subword pieces (the original `-> str`
        # annotation was incorrect).
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.fairseq_tokens_to_ids:
            return self.fairseq_tokens_to_ids[token]
        spm_id = self.sp_model.PieceToId(token)

        # Need to return unknown token if the SP model returned 0
        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.fairseq_ids_to_tokens:
            return self.fairseq_ids_to_tokens[index]
        return self.sp_model.IdToPiece(index - self.fairseq_offset)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings for sub-words) in a single string."""
        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Copy (or serialize) the sentencepiece model file into `save_directory`."""
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            # NOTE(review): returns None here despite the `Tuple[str]` annotation;
            # callers must handle the error path.
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone: re-serialize the in-memory sentencepiece model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A XLMProphetNet sequence has the following format:

        - single sequence: `X [SEP]`
        - pair of sequences: `A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return token_ids_0 + [self.sep_token_id]
        sep = [self.sep_token_id]
        return token_ids_0 + sep + token_ids_1 + sep
+
+
+__all__ = ["XLMProphetNetTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py b/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7425e37e0399c792155a88488045176fb3b5e7a5
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_depth_anything import *
+ from .modeling_depth_anything import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py b/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bbe621a44310c2b217fe00fd2fba653de3489d9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py
@@ -0,0 +1,168 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DepthAnything model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+logger = logging.get_logger(__name__)
+
+
class DepthAnythingConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DepthAnythingModel`]. It is used to instantiate a DepthAnything
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DepthAnything
    [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
            The configuration of the backbone model. Only used in case you want to leverage the [`AutoBackbone`] API.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
            API.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        patch_size (`int`, *optional*, defaults to 14):
            The size of the patches to extract from the backbone features.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        reassemble_hidden_size (`int`, *optional*, defaults to 384):
            The number of input channels of the reassemble layers.
        reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
            The up/downsampling factors of the reassemble layers.
        neck_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192, 384]`):
            The hidden sizes to project to for the feature maps of the backbone.
        fusion_hidden_size (`int`, *optional*, defaults to 64):
            The number of channels before fusion.
        head_in_index (`int`, *optional*, defaults to -1):
            The index of the features to use in the depth estimation head.
        head_hidden_size (`int`, *optional*, defaults to 32):
            The number of output channels in the second convolution of the depth estimation head.
        depth_estimation_type (`str`, *optional*, defaults to `"relative"`):
            The type of depth estimation to use. Can be one of `["relative", "metric"]`.
        max_depth (`float`, *optional*):
            The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models
            and 80 for outdoor models. For "relative" depth estimation, this value is ignored.

    Example:

    ```python
    >>> from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation

    >>> # Initializing a DepthAnything small style configuration
    >>> configuration = DepthAnythingConfig()

    >>> # Initializing a model from the DepthAnything small style configuration
    >>> model = DepthAnythingForDepthEstimation(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "depth_anything"

    def __init__(
        self,
        backbone_config=None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        backbone_kwargs=None,
        patch_size=14,
        initializer_range=0.02,
        reassemble_hidden_size=384,
        reassemble_factors=None,
        neck_hidden_sizes=None,
        fusion_hidden_size=64,
        head_in_index=-1,
        head_hidden_size=32,
        depth_estimation_type="relative",
        max_depth=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Bug fix: the list defaults were previously mutable default arguments
        # (`[4, 2, 1, 0.5]` / `[48, 96, 192, 384]`), shared across every call —
        # mutating one config's list would silently leak into all later configs.
        if reassemble_factors is None:
            reassemble_factors = [4, 2, 1, 0.5]
        if neck_hidden_sizes is None:
            neck_hidden_sizes = [48, 96, 192, 384]

        if backbone_config is None and backbone is None:
            logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.")
            backbone_config = CONFIG_MAPPING["dinov2"](
                image_size=518,
                hidden_size=384,
                num_attention_heads=6,
                out_indices=[9, 10, 11, 12],
                apply_layernorm=True,
                reshape_hidden_states=False,
            )
        elif isinstance(backbone_config, dict):
            # A serialized backbone config: rebuild the proper config class from its dict.
            backbone_model_type = backbone_config.get("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        # Raises if mutually-exclusive backbone arguments were combined.
        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.backbone_kwargs = backbone_kwargs
        self.reassemble_hidden_size = reassemble_hidden_size
        self.patch_size = patch_size
        self.initializer_range = initializer_range
        self.reassemble_factors = reassemble_factors
        self.neck_hidden_sizes = neck_hidden_sizes
        self.fusion_hidden_size = fusion_hidden_size
        self.head_in_index = head_in_index
        self.head_hidden_size = head_hidden_size
        if depth_estimation_type not in ["relative", "metric"]:
            raise ValueError("depth_estimation_type must be one of ['relative', 'metric']")
        self.depth_estimation_type = depth_estimation_type
        # Truthiness check preserved on purpose: an explicit 0 also falls back to 1.
        self.max_depth = max_depth if max_depth else 1

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)

        # The nested backbone config must itself be serialized to a plain dict.
        if output["backbone_config"] is not None:
            output["backbone_config"] = self.backbone_config.to_dict()

        output["model_type"] = self.__class__.model_type
        return output


__all__ = ["DepthAnythingConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py b/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..75d7e74adde07d32393fe51673c5577d221bf148
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py
@@ -0,0 +1,2048 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DETR."""
+
+import io
+import pathlib
+from collections import defaultdict
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ PaddingMode,
+ center_to_corners_format,
+ corners_to_center_format,
+ id_to_rgb,
+ pad,
+ rescale,
+ resize,
+ rgb_to_id,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_DEFAULT_MEAN,
+ IMAGENET_DEFAULT_STD,
+ AnnotationFormat,
+ AnnotationType,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_annotations,
+ validate_kwargs,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ is_flax_available,
+ is_jax_tensor,
+ is_scipy_available,
+ is_tf_available,
+ is_tf_tensor,
+ is_torch_available,
+ is_torch_tensor,
+ is_vision_available,
+ logging,
+)
+from ...utils.import_utils import requires
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+
+if is_vision_available():
+ import PIL
+
+
+if is_scipy_available():
+ import scipy.special
+ import scipy.stats
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC)
+
+
+# From the original repo: https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/datasets/transforms.py#L76
def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]:
    """
    Compute the output `(height, width)` for a shortest-edge resize of `size`,
    optionally clamped so the longest edge does not exceed `max_size`.

    Args:
        image_size (`Tuple[int, int]`):
            The input image size.
        size (`int`):
            The desired output size (of the shorter edge).
        max_size (`int`, *optional*):
            The maximum allowed output size (of the longer edge).
    """
    height, width = image_size

    # Un-rounded target shorter-edge length after clamping, if clamping applied.
    clamped_size = None
    if max_size is not None:
        shorter = float(min((height, width)))
        longer = float(max((height, width)))
        if longer / shorter * size > max_size:
            clamped_size = max_size * shorter / longer
            size = int(round(clamped_size))

    # The shorter edge already matches the requested size: keep the image as-is.
    if (height <= width and height == size) or (width <= height and width == size):
        return (height, width)

    if width < height:
        new_width = size
        if max_size is not None and clamped_size is not None:
            new_height = int(clamped_size * height / width)
        else:
            new_height = int(size * height / width)
    else:
        new_height = size
        if max_size is not None and clamped_size is not None:
            new_width = int(clamped_size * width / height)
        else:
            new_width = int(size * width / height)

    return (new_height, new_width)
+
+
def get_image_size_for_max_height_width(
    input_image: np.ndarray,
    max_height: int,
    max_width: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio.
    Important, even if image_height < max_height and image_width < max_width, the image will be resized
    to at least one of the edges be equal to max_height or max_width.

    For example:
        - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
        - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        max_height (`int`):
            The maximum allowed height.
        max_width (`int`):
            The maximum allowed width.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    height, width = get_image_size(input_image, input_data_format)
    # The more restrictive of the two scale factors guarantees both limits hold.
    scale = min(max_height / height, max_width / width)
    return int(height * scale), int(width * scale)
+
+
def get_resize_output_image_size(
    input_image: np.ndarray,
    size: Union[int, Tuple[int, int], List[int]],
    max_size: Optional[int] = None,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Computes the output image size given the input image size and the desired output size. If the desired output size
    is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output
    image size is computed by keeping the aspect ratio of the input image size.

    Args:
        input_image (`np.ndarray`):
            The image to resize.
        size (`int` or `Tuple[int, int]` or `List[int]`):
            The desired output size.
        max_size (`int`, *optional*):
            The maximum allowed output size.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image. If not provided, it will be inferred from the input image.
    """
    image_size = get_image_size(input_image, input_data_format)
    # An explicit (height, width) request is honored as-is; `max_size` only
    # applies for a single-integer, shortest-edge resize.
    if isinstance(size, (list, tuple)):
        return size

    return get_size_with_aspect_ratio(image_size, size, max_size)
+
+
def get_numpy_to_framework_fn(arr) -> Callable:
    """
    Returns a function that converts a numpy array to the framework of the input array.

    Args:
        arr (`np.ndarray`): The array to convert.
    """
    if isinstance(arr, np.ndarray):
        return np.array
    # Each framework check is guarded by an availability test so that probing
    # for TensorFlow/PyTorch/JAX tensors never requires the library itself;
    # the import happens lazily only once the match is confirmed.
    if is_tf_available() and is_tf_tensor(arr):
        import tensorflow as tf

        return tf.convert_to_tensor
    if is_torch_available() and is_torch_tensor(arr):
        import torch

        return torch.tensor
    if is_flax_available() and is_jax_tensor(arr):
        import jax.numpy as jnp

        return jnp.array
    raise ValueError(f"Cannot convert arrays of type {type(arr)}")
+
+
def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """
    Squeezes an array, but only if the axis specified has dim 1.
    """
    if axis is None:
        return arr.squeeze()

    try:
        return arr.squeeze(axis=axis)
    except ValueError:
        # numpy raises ValueError when the axis is not of length 1 (np.AxisError
        # is a ValueError subclass); in that case return the array unchanged.
        return arr
+
+
def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict:
    """
    Normalize the "boxes" entry of an annotation: convert corner boxes to
    (center_x, center_y, width, height) scaled to `[0, 1]` relative to the
    image size. All other entries are passed through untouched.
    """
    image_height, image_width = image_size
    scale = np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32)

    normalized = {}
    for key, value in annotation.items():
        if key == "boxes":
            normalized[key] = corners_to_center_format(value) / scale
        else:
            normalized[key] = value
    return normalized
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices
def max_across_indices(values: Iterable[Any]) -> List[Any]:
    """
    Return the maximum value across all indices of an iterable of values.
    """
    # zip(*values) transposes per-item sequences into per-index tuples.
    return list(map(max, zip(*values)))
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width
def get_max_height_width(
    images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> List[int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    if input_data_format is None:
        # Infer the layout from the first image; the batch is assumed uniform.
        input_data_format = infer_channel_dimension_format(images[0])

    shapes = [img.shape for img in images]
    if input_data_format == ChannelDimension.FIRST:
        _, max_height, max_width = max_across_indices(shapes)
    elif input_data_format == ChannelDimension.LAST:
        max_height, max_width, _ = max_across_indices(shapes)
    else:
        raise ValueError(f"Invalid channel dimension format: {input_data_format}")
    return (max_height, max_width)
+
+
+# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask
def make_pixel_mask(
    image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None
) -> np.ndarray:
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`np.ndarray`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the input image.
    """
    valid_height, valid_width = get_image_size(image, channel_dim=input_data_format)
    # Start from an all-padding mask, then mark the region the image occupies.
    pixel_mask = np.zeros(output_size, dtype=np.int64)
    pixel_mask[:valid_height, :valid_width] = 1
    return pixel_mask
+
+
+# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33
def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray:
    """
    Convert a COCO polygon annotation to a mask.

    Args:
        segmentations (`List[List[float]]`):
            List of polygons, each polygon represented by a list of x-y coordinates.
        height (`int`):
            Height of the mask.
        width (`int`):
            Width of the mask.
    """
    try:
        from pycocotools import mask as coco_mask
    except ImportError:
        raise ImportError("Pycocotools is not installed in your environment.")

    masks = []
    for polygons in segmentations:
        # Rasterize the polygon(s) via COCO's RLE codec.
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if decoded.ndim < 3:
            decoded = decoded[..., None]
        # Collapse multiple polygon parts of one object into a single binary mask.
        masks.append(np.any(np.asarray(decoded, dtype=np.uint8), axis=2))

    if not masks:
        return np.zeros((0, height, width), dtype=np.uint8)
    return np.stack(masks, axis=0)
+
+
+# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50
def prepare_coco_detection_annotation(
    image,
    target,
    return_segmentation_masks: bool = False,
    input_data_format: Optional[Union[ChannelDimension, str]] = None,
):
    """
    Convert the target in COCO format into the format expected by DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)

    image_id = target["image_id"]
    image_id = np.asarray([image_id], dtype=np.int64)

    # Get all COCO annotations for the given image.
    # Crowd regions (iscrowd == 1) are dropped from the training targets.
    annotations = target["annotations"]
    annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0]

    classes = [obj["category_id"] for obj in annotations]
    classes = np.asarray(classes, dtype=np.int64)

    # for conversion to coco api
    area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32)
    iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64)

    boxes = [obj["bbox"] for obj in annotations]
    # guard against no boxes via resizing
    boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
    # COCO boxes are (x, y, width, height); convert to (x0, y0, x1, y1) corners
    # and clamp them to the image bounds.
    boxes[:, 2:] += boxes[:, :2]
    boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width)
    boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height)

    # Keep only boxes that still have positive width and height after clamping.
    keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

    new_target = {}
    new_target["image_id"] = image_id
    new_target["class_labels"] = classes[keep]
    new_target["boxes"] = boxes[keep]
    new_target["area"] = area[keep]
    new_target["iscrowd"] = iscrowd[keep]
    new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64)

    if annotations and "keypoints" in annotations[0]:
        keypoints = [obj["keypoints"] for obj in annotations]
        # Converting the filtered keypoints list to a numpy array
        keypoints = np.asarray(keypoints, dtype=np.float32)
        # Apply the keep mask here to filter the relevant annotations
        keypoints = keypoints[keep]
        # NOTE(review): `num_keypoints` is the number of kept annotations, not the
        # keypoint count per annotation — confirm the (-1, 3) reshape is intended.
        num_keypoints = keypoints.shape[0]
        keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints
        new_target["keypoints"] = keypoints

    if return_segmentation_masks:
        segmentation_masks = [obj["segmentation"] for obj in annotations]
        masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width)
        new_target["masks"] = masks[keep]

    return new_target
+
+
def masks_to_boxes(masks: np.ndarray) -> np.ndarray:
    """
    Compute the bounding boxes around the provided panoptic segmentation masks.

    Args:
        masks: masks in format `[number_masks, height, width]` where N is the number of masks

    Returns:
        boxes: bounding boxes in format `[number_masks, 4]` in xyxy format
    """
    if masks.size == 0:
        return np.zeros((0, 4))

    height, width = masks.shape[-2:]
    # Per-pixel row (y) and column (x) coordinate grids.
    # see https://github.com/pytorch/pytorch/issues/50276
    grid_y, grid_x = np.meshgrid(
        np.arange(0, height, dtype=np.float32),
        np.arange(0, width, dtype=np.float32),
        indexing="ij",
    )

    num_masks = masks.shape[0]
    inverted = ~(np.array(masks, dtype=bool))

    # Max coordinate inside each mask: zeroed-out background never wins the max.
    x_weighted = masks * np.expand_dims(grid_x, axis=0)
    x_max = x_weighted.reshape(num_masks, -1).max(-1)
    # Min coordinate: mask out background with a huge fill value before the min.
    x_min = np.ma.array(x_weighted, mask=inverted).filled(fill_value=1e8).reshape(num_masks, -1).min(-1)

    y_weighted = masks * np.expand_dims(grid_y, axis=0)
    y_max = y_weighted.reshape(num_masks, -1).max(-1)
    y_min = np.ma.array(y_weighted, mask=inverted).filled(fill_value=1e8).reshape(num_masks, -1).min(-1)

    return np.stack([x_min, y_min, x_max, y_max], 1)
+
+
def prepare_coco_panoptic_annotation(
    image: np.ndarray,
    target: Dict,
    masks_path: Union[str, pathlib.Path],
    return_masks: bool = True,
    input_data_format: Union[ChannelDimension, str] = None,
) -> Dict:
    """
    Prepare a coco panoptic annotation for DETR.
    """
    image_height, image_width = get_image_size(image, channel_dim=input_data_format)
    # The panoptic mask PNG lives under `masks_path`, named by the annotation.
    annotation_path = pathlib.Path(masks_path) / target["file_name"]

    new_target = {}
    new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64)
    new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64)
    new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64)

    if "segments_info" in target:
        # The PNG encodes segment ids as RGB triplets; decode back to integer ids.
        masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32)
        masks = rgb_to_id(masks)

        # Broadcast-compare against each segment id to get one binary mask per segment.
        ids = np.array([segment_info["id"] for segment_info in target["segments_info"]])
        masks = masks == ids[:, None, None]
        masks = masks.astype(np.uint8)
        if return_masks:
            new_target["masks"] = masks
        new_target["boxes"] = masks_to_boxes(masks)
        new_target["class_labels"] = np.array(
            [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64
        )
        new_target["iscrowd"] = np.asarray(
            [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64
        )
        new_target["area"] = np.asarray(
            [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32
        )

    return new_target
+
+
def get_segmentation_image(
    masks: np.ndarray, input_size: Tuple, target_size: Tuple, stuff_equiv_classes, deduplicate=False
):
    # Turn per-query mask scores into a single RGB segmentation image of
    # `target_size`, assigning each pixel to its highest-scoring query.
    h, w = input_size
    final_h, final_w = target_size

    # Softmax across queries for every pixel.
    # NOTE(review): `masks.transpose(0, 1)` assumes `masks` is 2D, i.e. already
    # flattened to (num_queries, height * width) — confirm at the call site.
    m_id = scipy.special.softmax(masks.transpose(0, 1), -1)

    if m_id.shape[-1] == 0:
        # We didn't detect any mask :(
        m_id = np.zeros((h, w), dtype=np.int64)
    else:
        m_id = m_id.argmax(-1).reshape(h, w)

    if deduplicate:
        # Merge the masks corresponding to the same stuff class
        for equiv in stuff_equiv_classes.values():
            for eq_id in equiv:
                m_id[m_id == eq_id] = equiv[0]

    # Encode ids as RGB and rescale to the requested output resolution.
    seg_img = id_to_rgb(m_id)
    seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST)
    return seg_img
+
+
def get_mask_area(seg_img: np.ndarray, target_size: Tuple[int, int], n_classes: int) -> np.ndarray:
    """
    Count, for each of the first `n_classes` segment ids, how many pixels of the
    RGB-encoded segmentation image `seg_img` belong to it.
    """
    final_h, final_w = target_size
    # Recover integer segment ids from the RGB encoding.
    segment_ids = rgb_to_id(seg_img.astype(np.uint8).reshape(final_h, final_w, 3))
    return [(segment_ids == class_id).sum() for class_id in range(n_classes)]
+
+
def score_labels_from_class_probabilities(logits: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Softmax the logits over the last axis and return, per row, the winning
    probability (`scores`) and its class index (`labels`).
    """
    probabilities = scipy.special.softmax(logits, axis=-1)
    best = probabilities.argmax(-1, keepdims=True)
    best_scores = np.take_along_axis(probabilities, best, axis=-1)
    return best_scores.squeeze(-1), best.squeeze(-1)
+
+
def post_process_panoptic_sample(
    out_logits: np.ndarray,
    masks: np.ndarray,
    boxes: np.ndarray,
    processed_size: Tuple[int, int],
    target_size: Tuple[int, int],
    is_thing_map: Dict,
    threshold=0.85,
) -> Dict:
    """
    Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample.

    Args:
        out_logits (`np.ndarray`):
            The logits for this sample.
        masks (`np.ndarray`):
            The predicted segmentation masks for this sample.
        boxes (`np.ndarray`):
            The predicted bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y,
            width, height)` and values between `[0, 1]`, relative to the size the image (disregarding padding).
        processed_size (`Tuple[int, int]`):
            The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size
            after data augmentation but before batching.
        target_size (`Tuple[int, int]`):
            The target size of the image, `(height, width)` corresponding to the requested final size of the
            prediction.
        is_thing_map (`Dict`):
            A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not.
        threshold (`float`, *optional*, defaults to 0.85):
            The threshold used to binarize the segmentation masks.
    """
    # We filter empty queries and detections below the confidence threshold;
    # the last logit index is the "no object" class.
    scores, labels = score_labels_from_class_probabilities(out_logits)
    keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold)

    cur_scores = scores[keep]
    cur_classes = labels[keep]
    cur_boxes = center_to_corners_format(boxes[keep])

    if len(cur_boxes) != len(cur_classes):
        raise ValueError("Not as many boxes as there are classes")

    cur_masks = masks[keep]
    cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR)
    cur_masks = safe_squeeze(cur_masks, 1)
    b, h, w = cur_masks.shape

    # It may be that we have several predicted masks for the same stuff class.
    # In the following, we track the list of masks ids for each stuff class (they are merged later on)
    cur_masks = cur_masks.reshape(b, -1)
    stuff_equiv_classes = defaultdict(list)
    for k, label in enumerate(cur_classes):
        if not is_thing_map[label]:
            stuff_equiv_classes[label].append(k)

    seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True)
    # NOTE(review): `get_mask_area` reshapes its first argument to (final_h, final_w, 3);
    # passing the flattened `cur_masks` with `processed_size` here looks inconsistent
    # with the `seg_img`/`target_size` call inside the loop below — confirm upstream.
    area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores))

    # We filter out any mask that is too small.
    # Bug fix: `cur_classes` is a numpy array, so `size` is an attribute —
    # `cur_classes.size()` raised `TypeError` on this (always-taken) path.
    if cur_classes.size > 0:
        # We now filter empty masks as long as we find some
        filtered_small = np.array([a <= 4 for a in area], dtype=bool)
        while filtered_small.any():
            cur_masks = cur_masks[~filtered_small]
            cur_scores = cur_scores[~filtered_small]
            cur_classes = cur_classes[~filtered_small]
            seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True)
            area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores))
            filtered_small = np.array([a <= 4 for a in area], dtype=bool)
    else:
        # No detection survived filtering: fall back to a single dummy class id.
        cur_classes = np.ones((1, 1), dtype=np.int64)

    segments_info = [
        {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a}
        for i, (cat, a) in enumerate(zip(cur_classes, area))
    ]
    del cur_classes

    # Encode the final segmentation image as a PNG byte string (COCO panoptic format).
    with io.BytesIO() as out:
        PIL.Image.fromarray(seg_img).save(out, format="PNG")
        predictions = {"png_string": out.getvalue(), "segments_info": segments_info}

    return predictions
+
+
def resize_annotation(
    annotation: Dict[str, Any],
    orig_size: Tuple[int, int],
    target_size: Tuple[int, int],
    threshold: float = 0.5,
    resample: PILImageResampling = PILImageResampling.NEAREST,
):
    """
    Resizes an annotation to a target size.

    Args:
        annotation (`Dict[str, Any]`):
            The annotation dictionary.
        orig_size (`Tuple[int, int]`):
            The original size of the input image.
        target_size (`Tuple[int, int]`):
            The target size of the image, as returned by the preprocessing `resize` step.
        threshold (`float`, *optional*, defaults to 0.5):
            The threshold used to binarize the segmentation masks.
        resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`):
            The resampling filter to use when resizing the masks.
    """
    ratio_height, ratio_width = (
        float(new_dim) / float(old_dim) for new_dim, old_dim in zip(target_size, orig_size)
    )

    resized = {"size": target_size}
    for key, value in annotation.items():
        if key == "boxes":
            # Corner boxes scale component-wise: x by width ratio, y by height ratio.
            resized["boxes"] = value * np.asarray(
                [ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32
            )
        elif key == "area":
            # Areas scale by the product of both ratios.
            resized["area"] = value * (ratio_width * ratio_height)
        elif key == "masks":
            # Resize each mask individually, then re-binarize with `threshold`.
            resized_masks = np.array([resize(mask, target_size, resample=resample) for mask in value[:, None]])
            resized["masks"] = resized_masks.astype(np.float32)[:, 0] > threshold
        elif key == "size":
            resized["size"] = target_size
        else:
            resized[key] = value

    return resized
+
+
+# TODO - (Amy) make compatible with other frameworks
def binary_mask_to_rle(mask):
    """
    Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.

    Args:
        mask (`torch.Tensor` or `numpy.array`):
            A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
            segment_id or class_id.
    Returns:
        `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
        format.
    """
    if is_torch_tensor(mask):
        mask = mask.numpy()

    # Pad both ends with a zero so every run of ones has a detectable start and end.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # Positions where the value flips: even entries are run starts, odd ones run ends.
    flips = np.where(padded[1:] != padded[:-1])[0] + 1
    # Replace each run end with the run's length.
    flips[1::2] -= flips[::2]
    return list(flips)
+
+
+# TODO - (Amy) make compatible with other frameworks
def convert_segmentation_to_rle(segmentation):
    """
    Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.

    Args:
        segmentation (`torch.Tensor` or `numpy.array`):
            A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
    Returns:
        `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
    """
    # One RLE per distinct id: binarize the map for that id, then encode it.
    return [
        binary_mask_to_rle(torch.where(segmentation == segment_id, 1, 0))
        for segment_id in torch.unique(segmentation)
    ]
+
+
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
    """
    Filter query predictions, keeping only those that are confident real objects.

    Args:
        masks (`torch.Tensor`):
            A tensor of shape `(num_queries, height, width)`.
        scores (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        labels (`torch.Tensor`):
            A tensor of shape `(num_queries)`.
        object_mask_threshold (`float`):
            A number between 0 and 1 used to binarize the masks.
    Raises:
        `ValueError`: Raised when the first dimension doesn't match in all input tensors.
    Returns:
        `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
        < `object_mask_threshold`.
    """
    if masks.shape[0] != scores.shape[0] or scores.shape[0] != labels.shape[0]:
        raise ValueError("mask, scores and labels must have the same shape!")

    # Keep queries that are neither the "no object" class (index == num_labels)
    # nor below the confidence threshold.
    keep = (labels != num_labels) & (scores > object_mask_threshold)

    return masks[keep], scores[keep], labels[keep]
+
+
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
    """
    Decide whether query `k` yields a valid segment, returning the validity flag
    and the boolean mask of pixels currently assigned to `k`.
    """
    # Pixels assigned to query k in the argmax segmentation.
    mask_k = mask_labels == k
    assigned_area = mask_k.sum()

    # Pixels where query k's own probability clears the binarization threshold.
    thresholded_area = (mask_probs[k] >= mask_threshold).sum()

    mask_exists = assigned_area > 0 and thresholded_area > 0

    # Eliminate disconnected tiny segments: the assigned pixels must cover a
    # large enough fraction of the thresholded mask.
    if mask_exists:
        if not (assigned_area / thresholded_area).item() > overlap_mask_area_threshold:
            mask_exists = False

    return mask_exists, mask_k
+
+
def compute_segments(
    mask_probs,
    pred_scores,
    pred_labels,
    mask_threshold: float = 0.5,
    overlap_mask_area_threshold: float = 0.8,
    label_ids_to_fuse: Optional[Set[int]] = None,
    target_size: Optional[Tuple[int, int]] = None,
):
    """
    Compute a segmentation map of segment ids from per-query mask probabilities.

    Args:
        mask_probs (`torch.Tensor` of shape `(num_queries, height, width)`):
            Mask probabilities for each query. Modified in place (weighted by `pred_scores`).
        pred_scores (`torch.Tensor` of shape `(num_queries)`):
            Confidence score of each query; each mask is weighted by its score.
        pred_labels (`torch.Tensor` of shape `(num_queries)`):
            Predicted class id of each query.
        mask_threshold (`float`, *optional*, defaults to 0.5):
            Threshold used to binarize each query's mask when checking segment validity.
        overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
            Minimum fraction of a query's binarized mask that must survive the argmax assignment for
            the segment to be kept.
        label_ids_to_fuse (`Set[int]`, *optional*):
            Class ids whose instances are fused into a single segment. Defaults to fusing nothing.
        target_size (`Tuple[int, int]`, *optional*):
            If given, masks are bilinearly resized to `(height, width)` before computing the map.

    Returns:
        `Tuple[torch.Tensor, List[Dict]]`: The `(height, width)` int32 map of segment ids (0 =
        unassigned) and a list of per-segment metadata dicts (`id`, `label_id`, `was_fused`, `score`).
    """
    # Bug fix: the default `label_ids_to_fuse=None` used to raise `TypeError` on the
    # `pred_class in label_ids_to_fuse` membership test below; treat None as "fuse nothing".
    if label_ids_to_fuse is None:
        label_ids_to_fuse = set()

    height = mask_probs.shape[1] if target_size is None else target_size[0]
    width = mask_probs.shape[2] if target_size is None else target_size[1]

    segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
    segments: List[Dict] = []

    if target_size is not None:
        mask_probs = nn.functional.interpolate(
            mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
        )[0]

    current_segment_id = 0

    # Weigh each mask by its prediction score
    mask_probs *= pred_scores.view(-1, 1, 1)
    # Each pixel is assigned to the query with the highest weighted probability.
    mask_labels = mask_probs.argmax(0)  # [height, width]

    # Keep track of instances of each class
    stuff_memory_list: Dict[str, int] = {}
    for k in range(pred_labels.shape[0]):
        pred_class = pred_labels[k].item()
        should_fuse = pred_class in label_ids_to_fuse

        # Check if mask exists and large enough to be a segment
        mask_exists, mask_k = check_segment_validity(
            mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
        )

        if mask_exists:
            # Fused classes reuse the segment id of their first kept instance.
            if pred_class in stuff_memory_list:
                current_segment_id = stuff_memory_list[pred_class]
            else:
                current_segment_id += 1

            # Add current object segment to final segmentation map
            segmentation[mask_k] = current_segment_id
            segment_score = round(pred_scores[k].item(), 6)
            segments.append(
                {
                    "id": current_segment_id,
                    "label_id": pred_class,
                    "was_fused": should_fuse,
                    "score": segment_score,
                }
            )
            if should_fuse:
                stuff_memory_list[pred_class] = current_segment_id

    return segmentation, segments
+
+
+@requires(backends=("vision",))
+class DetrImageProcessor(BaseImageProcessor):
+ r"""
+ Constructs a Detr image processor.
+
+ Args:
+ format (`str`, *optional*, defaults to `"coco_detection"`):
+ Data format of the annotations. One of "coco_detection" or "coco_panoptic".
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be
+ overridden by the `do_resize` parameter in the `preprocess` method.
+ size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`):
+ Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter
+ in the `preprocess` method. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ do_rescale (`bool`, *optional*, defaults to `True`):
+ Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
+ `do_rescale` parameter in the `preprocess` method.
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+ `preprocess` method.
+ do_normalize (`bool`, *optional*, defaults to True):
+ Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the
+ `preprocess` method.
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
+ Mean values to use when normalizing the image. Can be a single value or a list of values, one for each
+ channel. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
+ Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one
+ for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method.
+ do_convert_annotations (`bool`, *optional*, defaults to `True`):
+ Controls whether to convert the annotations to the format expected by the DETR model. Converts the
+ bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
+ Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
+ do_pad (`bool`, *optional*, defaults to `True`):
+ Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
+ method. If `True`, padding will be applied to the bottom and right of the image with zeros.
+ If `pad_size` is provided, the image will be padded to the specified dimensions.
+ Otherwise, the image will be padded to the maximum height and width of the batch.
+ pad_size (`Dict[str, int]`, *optional*):
        The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+
+ model_input_names = ["pixel_values", "pixel_mask"]
+
    def __init__(
        self,
        format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Union[float, List[float]] = None,
        image_std: Union[float, List[float]] = None,
        do_convert_annotations: Optional[bool] = None,
        do_pad: bool = True,
        pad_size: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> None:
        # Legacy alias: `pad_and_return_pixel_mask` used to control padding; map it onto `do_pad`.
        if "pad_and_return_pixel_mask" in kwargs:
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` parameter is deprecated and will be removed in v4.26. "
                "Please specify in `size['longest_edge'] instead`.",
            )
            max_size = kwargs.pop("max_size")
        else:
            # Legacy behavior: a caller-provided `size` without `max_size` implies the old
            # default longest edge of 1333; no `size` at all means no implicit max_size.
            max_size = None if size is None else 1333

        size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333}
        # Normalize `size` (and the deprecated `max_size`) into a canonical size dict.
        size = get_size_dict(size, max_size=max_size, default_to_square=False)

        # Backwards compatibility
        if do_convert_annotations is None:
            do_convert_annotations = do_normalize

        # Remaining kwargs (if any) are handled by the base image processor.
        super().__init__(**kwargs)
        self.format = format
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.do_convert_annotations = do_convert_annotations
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.do_pad = do_pad
        self.pad_size = pad_size
        # Keyword arguments accepted by `preprocess`; used there to validate captured kwargs.
        self._valid_processor_keys = [
            "images",
            "annotations",
            "return_segmentation_masks",
            "masks_path",
            "do_resize",
            "size",
            "resample",
            "do_rescale",
            "rescale_factor",
            "do_normalize",
            "do_convert_annotations",
            "image_mean",
            "image_std",
            "do_pad",
            "pad_size",
            "format",
            "return_tensors",
            "data_format",
            "input_data_format",
        ]
+
+ @classmethod
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
+ """
+ Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is
+ created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600,
+ max_size=800)`
+ """
+ image_processor_dict = image_processor_dict.copy()
+ if "max_size" in kwargs:
+ image_processor_dict["max_size"] = kwargs.pop("max_size")
+ if "pad_and_return_pixel_mask" in kwargs:
+ image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask")
+ return super().from_dict(image_processor_dict, **kwargs)
+
+ def prepare_annotation(
+ self,
+ image: np.ndarray,
+ target: Dict,
+ format: Optional[AnnotationFormat] = None,
+ return_segmentation_masks: Optional[bool] = None,
+ masks_path: Optional[Union[str, pathlib.Path]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> Dict:
+ """
+ Prepare an annotation for feeding into DETR model.
+ """
+ format = format if format is not None else self.format
+
+ if format == AnnotationFormat.COCO_DETECTION:
+ return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_detection_annotation(
+ image, target, return_segmentation_masks, input_data_format=input_data_format
+ )
+ elif format == AnnotationFormat.COCO_PANOPTIC:
+ return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks
+ target = prepare_coco_panoptic_annotation(
+ image,
+ target,
+ masks_path=masks_path,
+ return_masks=return_segmentation_masks,
+ input_data_format=input_data_format,
+ )
+ else:
+ raise ValueError(f"Format {format} is not supported.")
+ return target
+
+ def resize(
+ self,
+ image: np.ndarray,
+ size: Dict[str, int],
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ) -> np.ndarray:
+ """
+ Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an
+ int, smaller edge of the image will be matched to this number.
+
+ Args:
+ image (`np.ndarray`):
+ Image to resize.
+ size (`Dict[str, int]`):
+ Size of the image's `(height, width)` dimensions after resizing. Available options are:
+ - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
+ Do NOT keep the aspect ratio.
+ - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
+ the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
+ less or equal to `longest_edge`.
+ - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
+ aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
+ `max_width`.
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+ Resampling filter to use if resizing the image.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ if "max_size" in kwargs:
+ logger.warning_once(
+ "The `max_size` parameter is deprecated and will be removed in v4.26. "
+ "Please specify in `size['longest_edge'] instead`.",
+ )
+ max_size = kwargs.pop("max_size")
+ else:
+ max_size = None
+ size = get_size_dict(size, max_size=max_size, default_to_square=False)
+ if "shortest_edge" in size and "longest_edge" in size:
+ new_size = get_resize_output_image_size(
+ image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format
+ )
+ elif "max_height" in size and "max_width" in size:
+ new_size = get_image_size_for_max_height_width(
+ image, size["max_height"], size["max_width"], input_data_format=input_data_format
+ )
+ elif "height" in size and "width" in size:
+ new_size = (size["height"], size["width"])
+ else:
+ raise ValueError(
+ "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
+ f" {size.keys()}."
+ )
+ image = resize(
+ image,
+ size=new_size,
+ resample=resample,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ **kwargs,
+ )
+ return image
+
+ def resize_annotation(
+ self,
+ annotation,
+ orig_size,
+ size,
+ resample: PILImageResampling = PILImageResampling.NEAREST,
+ ) -> Dict:
+ """
+ Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched
+ to this number.
+ """
+ return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample)
+
+ # TODO (Amy) - update to use `rescale_factor` instead of `scale`
+ def rescale(
+ self,
+ image: np.ndarray,
+ rescale_factor: float,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ ) -> np.ndarray:
+ """
+ Rescale the image by the given factor. image = image * rescale_factor.
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ rescale_factor (`float`):
+ The value to use for rescaling.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
+ image is used. Can be one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ input_data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format for the input image. If unset, is inferred from the input image. Can be
+ one of:
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+ """
+ return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format)
+
+ def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict:
+ """
+ Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to
+ `[center_x, center_y, width, height]` format and from absolute to relative pixel values.
+ """
+ return normalize_annotation(annotation, image_size=image_size)
+
+ def _update_annotation_for_padded_image(
+ self,
+ annotation: Dict,
+ input_image_size: Tuple[int, int],
+ output_image_size: Tuple[int, int],
+ padding,
+ update_bboxes,
+ ) -> Dict:
+ """
+ Update the annotation for a padded image.
+ """
+ new_annotation = {}
+ new_annotation["size"] = output_image_size
+
+ for key, value in annotation.items():
+ if key == "masks":
+ masks = value
+ masks = pad(
+ masks,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=0,
+ input_data_format=ChannelDimension.FIRST,
+ )
+ masks = safe_squeeze(masks, 1)
+ new_annotation["masks"] = masks
+ elif key == "boxes" and update_bboxes:
+ boxes = value
+ boxes *= np.asarray(
+ [
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ input_image_size[1] / output_image_size[1],
+ input_image_size[0] / output_image_size[0],
+ ]
+ )
+ new_annotation["boxes"] = boxes
+ elif key == "size":
+ new_annotation["size"] = output_image_size
+ else:
+ new_annotation[key] = value
+ return new_annotation
+
+ def _pad_image(
+ self,
+ image: np.ndarray,
+ output_size: Tuple[int, int],
+ annotation: Optional[Dict[str, Any]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ ) -> np.ndarray:
+ """
+ Pad an image with zeros to the given size.
+ """
+ input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ output_height, output_width = output_size
+
+ pad_bottom = output_height - input_height
+ pad_right = output_width - input_width
+ padding = ((0, pad_bottom), (0, pad_right))
+ padded_image = pad(
+ image,
+ padding,
+ mode=PaddingMode.CONSTANT,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ )
+ if annotation is not None:
+ annotation = self._update_annotation_for_padded_image(
+ annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes
+ )
+ return padded_image, annotation
+
+ def pad(
+ self,
+ images: List[np.ndarray],
+ annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
+ constant_values: Union[float, Iterable[float]] = 0,
+ return_pixel_mask: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ data_format: Optional[ChannelDimension] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ update_bboxes: bool = True,
+ pad_size: Optional[Dict[str, int]] = None,
+ ) -> BatchFeature:
+ """
+ Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
+ in the batch and optionally returns their corresponding pixel mask.
+
+ Args:
+ images (List[`np.ndarray`]):
+ Images to pad.
+ annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
+ Annotations to transform according to the padding that is applied to the images.
+ constant_values (`float` or `Iterable[float]`, *optional*):
+ The value to use for the padding if `mode` is `"constant"`.
+ return_pixel_mask (`bool`, *optional*, defaults to `True`):
+ Whether to return a pixel mask.
+ return_tensors (`str` or `TensorType`, *optional*):
+ The type of tensors to return. Can be one of:
+ - Unset: Return a list of `np.ndarray`.
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ update_bboxes (`bool`, *optional*, defaults to `True`):
+ Whether to update the bounding boxes in the annotations to match the padded images. If the
+ bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)`
+ format, the bounding boxes will not be updated.
+ pad_size (`Dict[str, int]`, *optional*):
+ The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
+ provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
+ height and width in the batch.
+ """
+ pad_size = pad_size if pad_size is not None else self.pad_size
+ if pad_size is not None:
+ padded_size = (pad_size["height"], pad_size["width"])
+ else:
+ padded_size = get_max_height_width(images, input_data_format=input_data_format)
+
+ annotation_list = annotations if annotations is not None else [None] * len(images)
+ padded_images = []
+ padded_annotations = []
+ for image, annotation in zip(images, annotation_list):
+ padded_image, padded_annotation = self._pad_image(
+ image,
+ padded_size,
+ annotation,
+ constant_values=constant_values,
+ data_format=data_format,
+ input_data_format=input_data_format,
+ update_bboxes=update_bboxes,
+ )
+ padded_images.append(padded_image)
+ padded_annotations.append(padded_annotation)
+
+ data = {"pixel_values": padded_images}
+
+ if return_pixel_mask:
+ masks = [
+ make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format)
+ for image in images
+ ]
+ data["pixel_mask"] = masks
+
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+
+ if annotations is not None:
+ encoded_inputs["labels"] = [
+ BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations
+ ]
+
+ return encoded_inputs
+
    def preprocess(
        self,
        images: ImageInput,
        annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None,
        return_segmentation_masks: Optional[bool] = None,
        masks_path: Optional[Union[str, pathlib.Path]] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample=None,  # PILImageResampling
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[Union[int, float]] = None,
        do_normalize: Optional[bool] = None,
        do_convert_annotations: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        format: Optional[Union[str, AnnotationFormat]] = None,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_size: Optional[Dict[str, int]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or a batch of images so that it can be used by the model.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging
                from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
                List of annotations associated with the image or batch of images. If annotation is for object
                detection, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
                  dictionary. An image can have no annotations, in which case the list should be empty.
                If annotation is for segmentation, the annotations should be a dictionary with the following keys:
                - "image_id" (`int`): The image id.
                - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
                  An image can have no segments, in which case the list should be empty.
                - "file_name" (`str`): The file name of the image.
            return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks):
                Whether to return segmentation masks.
            masks_path (`str` or `pathlib.Path`, *optional*):
                Path to the directory containing the segmentation masks.
            do_resize (`bool`, *optional*, defaults to self.do_resize):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to self.size):
                Size of the image's `(height, width)` dimensions after resizing. Available options are:
                - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`.
                  Do NOT keep the aspect ratio.
                - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting
                  the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge
                  less or equal to `longest_edge`.
                - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the
                  aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to
                  `max_width`.
            resample (`PILImageResampling`, *optional*, defaults to self.resample):
                Resampling filter to use when resizing the image.
            do_rescale (`bool`, *optional*, defaults to self.do_rescale):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to self.rescale_factor):
                Rescale factor to use when rescaling the image.
            do_normalize (`bool`, *optional*, defaults to self.do_normalize):
                Whether to normalize the image.
            do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations):
                Whether to convert the annotations to the format expected by the model. Converts the bounding
                boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)`
                and in relative coordinates.
            image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean):
                Mean to use when normalizing the image.
            image_std (`float` or `List[float]`, *optional*, defaults to self.image_std):
                Standard deviation to use when normalizing the image.
            do_pad (`bool`, *optional*, defaults to self.do_pad):
                Whether to pad the image. If `True`, padding will be applied to the bottom and right of
                the image with zeros. If `pad_size` is provided, the image will be padded to the specified
                dimensions. Otherwise, the image will be padded to the maximum height and width of the batch.
            format (`str` or `AnnotationFormat`, *optional*, defaults to self.format):
                Format of the annotations.
            return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors):
                Type of tensors to return. If `None`, will return the list of images.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            pad_size (`Dict[str, int]`, *optional*):
                The size `{"height": int, "width": int}` to pad the images to. Must be larger than any image size
                provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
                height and width in the batch.
        """
        # Map deprecated kwargs onto their current equivalents before resolving defaults.
        if "pad_and_return_pixel_mask" in kwargs:
            logger.warning_once(
                "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, "
                "use `do_pad` instead."
            )
            do_pad = kwargs.pop("pad_and_return_pixel_mask")

        if "max_size" in kwargs:
            logger.warning_once(
                "The `max_size` argument is deprecated and will be removed in a future version, use"
                " `size['longest_edge']` instead."
            )
            size = kwargs.pop("max_size")

        # Resolve every option against the instance-level defaults set in `__init__`.
        do_resize = self.do_resize if do_resize is None else do_resize
        size = self.size if size is None else size
        size = get_size_dict(size=size, default_to_square=False)
        resample = self.resample if resample is None else resample
        do_rescale = self.do_rescale if do_rescale is None else do_rescale
        rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
        do_normalize = self.do_normalize if do_normalize is None else do_normalize
        image_mean = self.image_mean if image_mean is None else image_mean
        image_std = self.image_std if image_std is None else image_std
        do_convert_annotations = (
            self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations
        )
        do_pad = self.do_pad if do_pad is None else do_pad
        pad_size = self.pad_size if pad_size is None else pad_size
        format = self.format if format is None else format

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        # Reject any leftover kwargs that are not recognized preprocessing options.
        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)

        # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated.
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        # A single annotation dict is treated as a batch of one.
        if annotations is not None and isinstance(annotations, dict):
            annotations = [annotations]

        if annotations is not None and len(images) != len(annotations):
            raise ValueError(
                f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match."
            )

        format = AnnotationFormat(format)
        if annotations is not None:
            validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations)

        # Panoptic annotations reference mask PNG files, so the directory path must be usable.
        if (
            masks_path is not None
            and format == AnnotationFormat.COCO_PANOPTIC
            and not isinstance(masks_path, (pathlib.Path, str))
        ):
            raise ValueError(
                "The path to the directory containing the mask PNG files should be provided as a"
                f" `pathlib.Path` or string object, but is {type(masks_path)} instead."
            )

        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image)
        if annotations is not None:
            prepared_images = []
            prepared_annotations = []
            for image, target in zip(images, annotations):
                target = self.prepare_annotation(
                    image,
                    target,
                    format,
                    return_segmentation_masks=return_segmentation_masks,
                    masks_path=masks_path,
                    input_data_format=input_data_format,
                )
                prepared_images.append(image)
                prepared_annotations.append(target)
            images = prepared_images
            annotations = prepared_annotations
            del prepared_images, prepared_annotations

        # transformations
        if do_resize:
            if annotations is not None:
                # Resize each image and its annotation together so they stay aligned.
                resized_images, resized_annotations = [], []
                for image, target in zip(images, annotations):
                    orig_size = get_image_size(image, input_data_format)
                    resized_image = self.resize(
                        image, size=size, resample=resample, input_data_format=input_data_format
                    )
                    resized_annotation = self.resize_annotation(
                        target, orig_size, get_image_size(resized_image, input_data_format)
                    )
                    resized_images.append(resized_image)
                    resized_annotations.append(resized_annotation)
                images = resized_images
                annotations = resized_annotations
                del resized_images, resized_annotations
            else:
                images = [
                    self.resize(image, size=size, resample=resample, input_data_format=input_data_format)
                    for image in images
                ]

        if do_rescale:
            images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images]

        if do_normalize:
            images = [
                self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images
            ]

        # Convert boxes to relative (center_x, center_y, width, height) BEFORE padding, so the
        # subsequent padding step can rescale them consistently (see `update_bboxes` below).
        if do_convert_annotations and annotations is not None:
            annotations = [
                self.normalize_annotation(annotation, get_image_size(image, input_data_format))
                for annotation, image in zip(annotations, images)
            ]

        if do_pad:
            # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...}
            encoded_inputs = self.pad(
                images,
                annotations=annotations,
                return_pixel_mask=True,
                data_format=data_format,
                input_data_format=input_data_format,
                update_bboxes=do_convert_annotations,
                return_tensors=return_tensors,
                pad_size=pad_size,
            )
        else:
            images = [
                to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
                for image in images
            ]
            encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
            if annotations is not None:
                encoded_inputs["labels"] = [
                    BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations
                ]

        return encoded_inputs
+
+ # POSTPROCESSING METHODS - TODO: add support for other frameworks
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
+    def post_process(self, outputs, target_sizes):
+        """
+        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+                Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the
+                original image size (before any data augmentation). For visualization, this should be the image size
+                after data augment, but before padding.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        # Deprecated entry point kept for backward compatibility; new code should call
+        # `post_process_object_detection` (with `threshold=0.` for identical results).
+        logger.warning_once(
+            "`post_process` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.",
+        )
+
+        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+        if len(out_logits) != len(target_sizes):
+            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
+        if target_sizes.shape[1] != 2:
+            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
+
+        # Softmax over classes; the last logit is the "no object" class, so it is dropped
+        # (`[..., :-1]`) before taking the per-query max score/label.
+        prob = nn.functional.softmax(out_logits, -1)
+        scores, labels = prob[..., :-1].max(-1)
+
+        # convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(out_bbox)
+        # and from relative [0, 1] to absolute [0, height] coordinates
+        img_h, img_w = target_sizes.unbind(1)
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+        boxes = boxes * scale_fct[:, None, :]
+
+        results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
+        return results
+
+    def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5):
+        """
+        Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrSegmentationOutput`]):
+                Raw outputs of the model.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
+                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction.
+            threshold (`float`, *optional*, defaults to 0.9):
+                Threshold to use to filter out queries.
+            mask_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold to use when turning the predicted masks into binary values.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image
+            in the batch as predicted by the model.
+        """
+        # Deprecated entry point; superseded by `post_process_semantic_segmentation`.
+        logger.warning_once(
+            "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_semantic_segmentation`.",
+        )
+        out_logits, raw_masks = outputs.logits, outputs.pred_masks
+        # Index of the "no object" class: the last entry of the class dimension.
+        empty_label = out_logits.shape[-1] - 1
+        preds = []
+
+        def to_tuple(tup):
+            # Normalize a size given as a tensor/list into a plain tuple of ints.
+            if isinstance(tup, tuple):
+                return tup
+            return tuple(tup.tolist())
+
+        for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes):
+            # we filter empty queries and detection below threshold
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
+            cur_scores = cur_scores[keep]
+            cur_labels = cur_labels[keep]
+            cur_masks = cur_masks[keep]
+            # Upsample the kept mask logits to the requested (h, w), then binarize.
+            # `* 1` converts the boolean mask to an integer tensor.
+            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+            cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1
+
+            predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks}
+            preds.append(predictions)
+        return preds
+
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
+    def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5):
+        """
+        Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports
+        PyTorch.
+
+        Args:
+            results (`List[Dict]`):
+                Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added.
+            outputs ([`DetrSegmentationOutput`]):
+                Raw outputs of the model.
+            orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+                Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original
+                image size (before any data augmentation).
+            max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`):
+                Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the
+                original image size (before any data augmentation).
+            threshold (`float`, *optional*, defaults to 0.5):
+                Threshold to use when turning the predicted masks into binary values.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an
+            image in the batch as predicted by the model.
+        """
+        # Deprecated entry point; superseded by `post_process_instance_segmentation`.
+        logger.warning_once(
+            "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_instance_segmentation`.",
+        )
+
+        if len(orig_target_sizes) != len(max_target_sizes):
+            raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes")
+        # Upsample every mask to the largest (h, w) in the batch so they can be processed as one tensor.
+        max_h, max_w = max_target_sizes.max(0)[0].tolist()
+        outputs_masks = outputs.pred_masks.squeeze(2)
+        outputs_masks = nn.functional.interpolate(
+            outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False
+        )
+        outputs_masks = (outputs_masks.sigmoid() > threshold).cpu()
+
+        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
+            # Crop each mask back to its own (pre-padding) size, then resize to the
+            # original image size with nearest-neighbor to keep the mask binary.
+            img_h, img_w = t[0], t[1]
+            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
+            results[i]["masks"] = nn.functional.interpolate(
+                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
+            ).byte()
+
+        return results
+
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
+    def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85):
+        """
+        Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrSegmentationOutput`]):
+                Raw outputs of the model.
+            processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`):
+                Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data
+                augmentation but before batching.
+            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
+                Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction.
+                If left to None, it will default to the `processed_sizes`.
+            is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*):
+                Dictionary mapping class indices to either True or False, depending on whether or not they are a thing.
+                If not set, defaults to the `is_thing_map` of COCO panoptic.
+            threshold (`float`, *optional*, defaults to 0.85):
+                Threshold to use to filter out queries.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for
+            an image in the batch as predicted by the model.
+        """
+        # Deprecated entry point; superseded by `post_process_panoptic_segmentation`.
+        logger.warning_once(
+            "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use"
+            " `post_process_panoptic_segmentation`.",
+        )
+        if target_sizes is None:
+            target_sizes = processed_sizes
+        if len(processed_sizes) != len(target_sizes):
+            raise ValueError("Make sure to pass in as many processed_sizes as target_sizes")
+
+        if is_thing_map is None:
+            # default to is_thing_map of COCO panoptic: category ids <= 90 are "things",
+            # the remaining ids (up to 200) are "stuff".
+            is_thing_map = {i: i <= 90 for i in range(201)}
+
+        out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes
+        if not len(out_logits) == len(raw_masks) == len(target_sizes):
+            raise ValueError(
+                "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks"
+            )
+        # Index of the "no object" class: the last entry of the class dimension.
+        empty_label = out_logits.shape[-1] - 1
+        preds = []
+
+        def to_tuple(tup):
+            # Normalize a size given as a tensor/list into a plain tuple of ints.
+            if isinstance(tup, tuple):
+                return tup
+            return tuple(tup.tolist())
+
+        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
+            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
+        ):
+            # we filter empty queries and detection below threshold
+            cur_scores, cur_labels = cur_logits.softmax(-1).max(-1)
+            keep = cur_labels.ne(empty_label) & (cur_scores > threshold)
+            cur_scores = cur_scores[keep]
+            cur_labels = cur_labels[keep]
+            cur_masks = cur_masks[keep]
+            cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1)
+            cur_boxes = center_to_corners_format(cur_boxes[keep])
+
+            h, w = cur_masks.shape[-2:]
+            if len(cur_boxes) != len(cur_labels):
+                raise ValueError("Not as many boxes as there are classes")
+
+            # It may be that we have several predicted masks for the same stuff class.
+            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
+            cur_masks = cur_masks.flatten(1)
+            stuff_equiv_classes = defaultdict(lambda: [])
+            for k, label in enumerate(cur_labels):
+                if not is_thing_map[label.item()]:
+                    stuff_equiv_classes[label.item()].append(k)
+
+            def get_ids_area(masks, scores, dedup=False):
+                # This helper function creates the final panoptic segmentation image
+                # It also returns the area of the masks that appears on the image
+
+                # Per-pixel soft assignment across queries, then hard argmax to get one
+                # segment id per pixel.
+                m_id = masks.transpose(0, 1).softmax(-1)
+
+                if m_id.shape[-1] == 0:
+                    # We didn't detect any mask :(
+                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
+                else:
+                    m_id = m_id.argmax(-1).view(h, w)
+
+                if dedup:
+                    # Merge the masks corresponding to the same stuff class
+                    for equiv in stuff_equiv_classes.values():
+                        if len(equiv) > 1:
+                            for eq_id in equiv:
+                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])
+
+                final_h, final_w = to_tuple(target_size)
+
+                # Round-trip through an RGB PNG image so the id map can be resized with
+                # nearest-neighbor at the requested final size.
+                seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy()))
+                seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST)
+
+                np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes()))
+                np_seg_img = np_seg_img.view(final_h, final_w, 3)
+                np_seg_img = np_seg_img.numpy()
+
+                m_id = torch.from_numpy(rgb_to_id(np_seg_img))
+
+                # Pixel count of each segment id at the final resolution.
+                area = []
+                for i in range(len(scores)):
+                    area.append(m_id.eq(i).sum().item())
+                return area, seg_img
+
+            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
+            if cur_labels.numel() > 0:
+                # We know filter empty masks as long as we find some
+                # (iterate because removing a mask changes the per-pixel argmax, which can
+                # shrink other segments below the 4-pixel minimum).
+                while True:
+                    filtered_small = torch.as_tensor(
+                        [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device
+                    )
+                    if filtered_small.any().item():
+                        cur_scores = cur_scores[~filtered_small]
+                        cur_labels = cur_labels[~filtered_small]
+                        cur_masks = cur_masks[~filtered_small]
+                        area, seg_img = get_ids_area(cur_masks, cur_scores)
+                    else:
+                        break
+
+            else:
+                # Nothing detected: emit a single placeholder label so segments_info is well-formed.
+                cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device)
+
+            segments_info = []
+            for i, a in enumerate(area):
+                cat = cur_labels[i].item()
+                segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a})
+            del cur_labels
+
+            # Serialize the segmentation image as PNG bytes (COCO panoptic submission format).
+            with io.BytesIO() as out:
+                seg_img.save(out, format="PNG")
+                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
+            preds.append(predictions)
+        return preds
+
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258
+    def post_process_object_detection(
+        self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None
+    ):
+        """
+        Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
+        bottom_right_x, bottom_right_y) format. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrObjectDetectionOutput`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image
+            in the batch as predicted by the model.
+        """
+        out_logits, out_bbox = outputs.logits, outputs.pred_boxes
+
+        if target_sizes is not None:
+            if len(out_logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+        # Softmax over classes; the last logit is the "no object" class and is dropped
+        # before taking the per-query max score/label.
+        prob = nn.functional.softmax(out_logits, -1)
+        scores, labels = prob[..., :-1].max(-1)
+
+        # Convert to [x0, y0, x1, y1] format
+        boxes = center_to_corners_format(out_bbox)
+
+        # Convert from relative [0, 1] to absolute [0, height] coordinates
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.Tensor([i[0] for i in target_sizes])
+                img_w = torch.Tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device)
+            boxes = boxes * scale_fct[:, None, :]
+
+        # Keep only detections whose score exceeds `threshold`, per image.
+        results = []
+        for s, l, b in zip(scores, labels, boxes):
+            score = s[s > threshold]
+            label = l[s > threshold]
+            box = b[s > threshold]
+            results.append({"scores": score, "labels": label, "boxes": box})
+
+        return results
+
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple[int, int]] = None):
+        """
+        Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrForSegmentation`]):
+                Raw outputs of the model.
+            target_sizes (`List[Tuple[int, int]]`, *optional*):
+                A list of tuples (`Tuple[int, int]`) containing the target size (height, width) of each image in the
+                batch. If unset, predictions will not be resized.
+        Returns:
+            `List[torch.Tensor]`:
+                A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
+                corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
+                `torch.Tensor` correspond to a semantic class id.
+        """
+        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
+        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]
+
+        # Remove the null class `[..., :-1]`
+        masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
+        masks_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]
+
+        # Semantic segmentation logits of shape (batch_size, num_classes, height, width):
+        # each query's mask is weighted by its class probabilities and summed over queries.
+        segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
+        batch_size = class_queries_logits.shape[0]
+
+        # Resize logits and compute semantic segmentation maps
+        if target_sizes is not None:
+            if batch_size != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
+            semantic_segmentation = []
+            for idx in range(batch_size):
+                # Upsample the class logits first, then argmax, so class boundaries are
+                # interpolated in logit space rather than on hard labels.
+                resized_logits = nn.functional.interpolate(
+                    segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = segmentation.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
+
+        return semantic_segmentation
+
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218
+    def post_process_instance_segmentation(
+        self,
+        outputs,
+        threshold: float = 0.5,
+        mask_threshold: float = 0.5,
+        overlap_mask_area_threshold: float = 0.8,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
+        return_coco_annotation: Optional[bool] = False,
+    ) -> List[Dict]:
+        """
+        Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch.
+
+        Args:
+            outputs ([`DetrForSegmentation`]):
+                Raw outputs of the model.
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold to keep predicted instance masks.
+            mask_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold to use when turning the predicted masks into binary values.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+                The overlap mask area threshold to merge or discard small disconnected parts within each binary
+                instance mask.
+            target_sizes (`List[Tuple]`, *optional*):
+                List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
+                final size (height, width) of each prediction. If unset, predictions will not be resized.
+            return_coco_annotation (`bool`, *optional*):
+                Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE)
+                format.
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+            - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or
+              `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to
+              `True`. Set to `None` if no mask if found above `threshold`.
+            - **segments_info** -- A dictionary that contains additional information on each segment.
+                - **id** -- An integer representing the `segment_id`.
+                - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+                - **score** -- Prediction score of segment with `segment_id`.
+        """
+        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
+        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]
+
+        batch_size = class_queries_logits.shape[0]
+        # Last class index is the "no object" class; exclude it from the label count.
+        num_labels = class_queries_logits.shape[-1] - 1
+
+        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]
+
+        # Predicted label and score of each query (batch_size, num_queries)
+        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
+
+        # Loop over items in batch size
+        results: List[Dict[str, TensorType]] = []
+
+        for i in range(batch_size):
+            # Drop low-scoring queries and queries predicting the "no object" class.
+            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
+                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
+            )
+
+            # No mask found
+            if mask_probs_item.shape[0] <= 0:
+                # All pixels set to -1 (the "unlabeled" sentinel) when nothing survives the filter.
+                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
+                segmentation = torch.zeros((height, width)) - 1
+                results.append({"segmentation": segmentation, "segments_info": []})
+                continue
+
+            # Get segmentation map and segment information of batch item
+            target_size = target_sizes[i] if target_sizes is not None else None
+            # `label_ids_to_fuse=[]`: instance segmentation never fuses instances of the same class.
+            segmentation, segments = compute_segments(
+                mask_probs=mask_probs_item,
+                pred_scores=pred_scores_item,
+                pred_labels=pred_labels_item,
+                mask_threshold=mask_threshold,
+                overlap_mask_area_threshold=overlap_mask_area_threshold,
+                label_ids_to_fuse=[],
+                target_size=target_size,
+            )
+
+            # Return segmentation map in run-length encoding (RLE) format
+            if return_coco_annotation:
+                segmentation = convert_segmentation_to_rle(segmentation)
+
+            results.append({"segmentation": segmentation, "segments_info": segments})
+        return results
+
+ # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241
+    def post_process_panoptic_segmentation(
+        self,
+        outputs,
+        threshold: float = 0.5,
+        mask_threshold: float = 0.5,
+        overlap_mask_area_threshold: float = 0.8,
+        label_ids_to_fuse: Optional[Set[int]] = None,
+        target_sizes: Optional[List[Tuple[int, int]]] = None,
+    ) -> List[Dict]:
+        """
+        Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports
+        PyTorch.
+
+        Args:
+            outputs ([`DetrForSegmentation`]):
+                The outputs from [`DetrForSegmentation`].
+            threshold (`float`, *optional*, defaults to 0.5):
+                The probability score threshold to keep predicted instance masks.
+            mask_threshold (`float`, *optional*, defaults to 0.5):
+                Threshold to use when turning the predicted masks into binary values.
+            overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
+                The overlap mask area threshold to merge or discard small disconnected parts within each binary
+                instance mask.
+            label_ids_to_fuse (`Set[int]`, *optional*):
+                The labels in this state will have all their instances be fused together. For instance we could say
+                there can only be one sky in an image, but several persons, so the label ID for sky would be in that
+                set, but not the one for person.
+            target_sizes (`List[Tuple]`, *optional*):
+                List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested
+                final size (height, width) of each prediction in batch. If unset, predictions will not be resized.
+        Returns:
+            `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
+            - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or
+              `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to
+              the corresponding `target_sizes` entry.
+            - **segments_info** -- A dictionary that contains additional information on each segment.
+                - **id** -- an integer representing the `segment_id`.
+                - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
+                - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
+                  Multiple instances of the same class / label were fused and assigned a single `segment_id`.
+                - **score** -- Prediction score of segment with `segment_id`.
+        """
+
+        if label_ids_to_fuse is None:
+            logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.")
+            label_ids_to_fuse = set()
+
+        class_queries_logits = outputs.logits  # [batch_size, num_queries, num_classes+1]
+        masks_queries_logits = outputs.pred_masks  # [batch_size, num_queries, height, width]
+
+        batch_size = class_queries_logits.shape[0]
+        # Last class index is the "no object" class; exclude it from the label count.
+        num_labels = class_queries_logits.shape[-1] - 1
+
+        mask_probs = masks_queries_logits.sigmoid()  # [batch_size, num_queries, height, width]
+
+        # Predicted label and score of each query (batch_size, num_queries)
+        pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
+
+        # Loop over items in batch size
+        results: List[Dict[str, TensorType]] = []
+
+        for i in range(batch_size):
+            # Drop low-scoring queries and queries predicting the "no object" class.
+            mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
+                mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
+            )
+
+            # No mask found
+            if mask_probs_item.shape[0] <= 0:
+                # All pixels set to -1 (the "unlabeled" sentinel) when nothing survives the filter.
+                height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
+                segmentation = torch.zeros((height, width)) - 1
+                results.append({"segmentation": segmentation, "segments_info": []})
+                continue
+
+            # Get segmentation map and segment information of batch item
+            target_size = target_sizes[i] if target_sizes is not None else None
+            segmentation, segments = compute_segments(
+                mask_probs=mask_probs_item,
+                pred_scores=pred_scores_item,
+                pred_labels=pred_labels_item,
+                mask_threshold=mask_threshold,
+                overlap_mask_area_threshold=overlap_mask_area_threshold,
+                label_ids_to_fuse=label_ids_to_fuse,
+                target_size=target_size,
+            )
+
+            results.append({"segmentation": segmentation, "segments_info": segments})
+        return results
+
+
+__all__ = ["DetrImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py b/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d736cb4f99d912ea8493e35f886be6f4702a21
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py
@@ -0,0 +1,1389 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_diffllama.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+ _flash_attention_forward,
+ flash_attn_supports_top_left_mask,
+)
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ LossKwargs,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_diffllama import DiffLlamaConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
+_CONFIG_FOR_DOC = "DiffLlamaConfig"
+
+
+class DiffLlamaMLP(nn.Module):
+    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
+
+    Projects from `hidden_size` to `intermediate_size` twice (gate and up paths,
+    both bias-free), multiplies them elementwise after applying the configured
+    activation to the gate, then projects back down to `hidden_size`.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        # Activation is selected by name from the config (e.g. via ACT2FN lookup).
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        # Gated projection: activation applies only to the gate path.
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input.
+
+    Splits the last dimension in two halves `(x1, x2)` and returns
+    `(-x2, x1)` concatenated — the 90° rotation used by rotary embeddings.
+    """
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    # Insert a broadcast (heads) dimension so cos/sin line up with q and k.
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    # Standard RoPE formula: x * cos + rotate_half(x) * sin, applied to q and k.
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        # Nothing to repeat (multi-head attention rather than grouped-query attention).
+        return hidden_states
+    # expand() creates a broadcast view (no copy); reshape() then materializes the
+    # repeated heads into a contiguous layout.
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def lambda_init_fn(layer_idx):
+    """Depth-dependent initial value for the differential-attention lambda.
+
+    Evaluates `0.8 - 0.6 * exp(-0.3 * layer_idx)`: 0.2 at layer 0, increasing
+    monotonically toward 0.8 for deeper layers.
+    """
+    return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+
+
+class DiffLlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+        """Build the attention projections and differential-attention parameters.
+
+        Args:
+            config (`DiffLlamaConfig`): Model configuration.
+            layer_idx (`int`, *optional*):
+                Index of this layer; required for KV caching and used to seed
+                `lambda_init` via `lambda_init_fn`. NOTE: passing `None` raises a
+                TypeError below in `lambda_init_fn` (math.exp on None) — callers are
+                expected to always provide it despite the Optional annotation.
+        """
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        # head_dim may be set explicitly in the config; otherwise derived from hidden_size.
+        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+        self.num_key_value_heads = config.num_key_value_heads
+        # Ratio of query heads to KV heads (grouped-query attention factor).
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        # The attributes below are stored from the config but not used directly in this class.
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+        # Differential-attention parameters: a fixed depth-dependent lambda_init plus
+        # four learned vectors (normally initialized with std `config.lambda_std_dev`)
+        # of size head_dim.
+        self.lambda_init = lambda_init_fn(layer_idx)
+        self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+        # Scale-only-free RMSNorm over the doubled head dimension (elementwise_affine=False).
+        self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, target_len, _ = hidden_states.size()
+ q_len = target_len
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
class DiffLlamaFlashAttention2(DiffLlamaAttention):
    """
    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.

    The differential scheme is realized by running the flash kernel twice, once per
    value-head half, and subtracting the two outputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Differential attention via two `_flash_attention_forward` calls.

        Always returns ``(attn_output, None)``: the flash kernel does not materialize
        attention probabilities, so ``output_attentions`` is forced off below.
        """
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
            )

        # Flash attention cannot return attention weights.
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (DiffLlamaRMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Split the value heads into two halves (layout here is (bsz, seq, heads, head_dim),
        # so the head axis is dim=2) and tile each half back to the full head count; each
        # half feeds one of the two flash-attention passes below.
        value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
        value_states1 = value_states1.repeat(1, 1, 2, 1)
        value_states2 = value_states2.repeat(1, 1, 2, 1)

        attn_output1 = _flash_attention_forward(
            query_states,
            key_states,
            value_states1,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output2 = _flash_attention_forward(
            query_states,
            key_states,
            value_states2,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        # Concatenate along the feature dim then re-chunk along the head axis: this turns
        # the two (bsz, q, heads, head_dim) outputs into two (bsz, q, heads/2, 2*head_dim)
        # differential outputs matching the eager path's layout.
        attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)

        # lambda_full = exp(lambda_q1·lambda_k1) - exp(lambda_q2·lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        # output_attentions is always False here, so attn_weights is always None.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
+
+
class DiffLlamaSdpaAttention(DiffLlamaAttention):
    """
    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from DiffLlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Differential attention via SDPA; falls back to the eager parent
        implementation when ``output_attentions=True`` (SDPA cannot return weights).

        Always returns ``(attn_output, None)`` on the SDPA path.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, seq, heads * head_dim) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        # Pair up value heads exactly as in the eager path: halve the head axis, widen
        # the feature dim to 2*head_dim, then tile back to the full head count.
        value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
        value_states = value_states.repeat(1, 2, 1, 1)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # First/second half of the head axis hold the two differential attention outputs.
        attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)

        # lambda_full = exp(lambda_q1·lambda_k1) - exp(lambda_q2·lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
            query_states.dtype
        )
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        attn_output = attn_output1 - lambda_full * attn_output2
        attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        return attn_output, None
+
+
@use_kernel_forward_from_hub("RMSNorm")
class DiffLlamaRMSNorm(nn.Module):
    """Root-mean-square layer normalization, equivalent to T5LayerNorm
    (no mean subtraction, learnable per-channel scale)."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create the per-channel scale (initialized to ones) and store the epsilon."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in float32 for numerical stability, then cast back.
        original_dtype = hidden_states.dtype
        states = hidden_states.to(torch.float32)
        mean_square = states.pow(2).mean(-1, keepdim=True)
        normalized = states * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        # Shown in the module's repr, e.g. "(4096,), eps=1e-06".
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
# Maps `config._attn_implementation` to the attention class instantiated by
# `DiffLlamaDecoderLayer`.
DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}
+
+
class DiffLlamaDecoderLayer(GradientCheckpointingLayer):
    """One DiffLlama transformer block: differential self-attention followed by an
    MLP, each wrapped in a pre-norm residual connection."""

    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Pick the attention implementation requested by the config
        # ("eager", "flash_attention_2" or "sdpa").
        attention_class = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_class(config=config, layer_idx=layer_idx)

        self.mlp = DiffLlamaMLP(config)
        self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder block and return `(hidden_states,)`, plus the attention
        weights when `output_attentions=True`."""
        # Self-attention sub-block (pre-norm residual).
        skip = hidden_states
        normed = self.input_layernorm(hidden_states)
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = skip + attn_out

        # Feed-forward sub-block (pre-norm residual).
        skip = hidden_states
        hidden_states = skip + self.mlp(self.post_attention_layernorm(hidden_states))

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
+
+
# Class-level documentation injected into the model classes via `add_start_docstrings`.
DIFFLLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DiffLlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.",
    DIFFLLAMA_START_DOCSTRING,
)
class DiffLlamaPreTrainedModel(PreTrainedModel):
    """Base class that wires DiffLlama modules into the transformers loading,
    saving, and dispatch machinery via the capability flags below."""

    config_class = DiffLlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DiffLlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = False

    def _init_weights(self, module):
        """Initialize one submodule's parameters in place (applied recursively by
        the transformers weight-init machinery)."""
        init_std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=init_std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, DiffLlamaRMSNorm):  # noqa: F821
            nn.init.ones_(module.weight)
        elif isinstance(module, DiffLlamaAttention):
            # Re-draw the differential-attention lambda vectors with the configured spread.
            for lam in (module.lambda_q1, module.lambda_k1, module.lambda_q2, module.lambda_k2):
                lam.data.normal_(0, self.config.lambda_std_dev)
+
+
class DiffLlamaRotaryEmbedding(nn.Module):
    """Produces the rotary position embedding (cos/sin) tables consumed by
    `apply_rotary_pos_emb`, supporting the rope-scaling variants registered in
    `ROPE_INIT_FUNCTIONS`."""

    def __init__(self, config: DiffLlamaConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # `inv_freq` is a non-persistent buffer: moved with the module but excluded
        # from the state dict. `original_inv_freq` keeps the unscaled copy for
        # dynamic rope types.
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # Expand to (batch, dim/2, 1) and (batch, 1, seq) so the matmul below yields
        # per-position angles of shape (batch, seq, dim/2).
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # NOTE(review): `x.device.type` is always a str, so the isinstance check is
        # always true; effectively this just routes "mps" devices to "cpu" autocast.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
# Argument documentation injected into `DiffLlamaModel.forward` via
# `add_start_docstrings_to_model_forward`.
DIFFLLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask,
            but you can also pass a `BlockMask` object directly here.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
+
+
+@add_start_docstrings(
+ "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.",
+ DIFFLLAMA_START_DOCSTRING,
+)
+class DiffLlamaModel(DiffLlamaPreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`]
+
+ Args:
+ config: DiffLlamaConfig
+ """
+
+ def __init__(self, config: DiffLlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.rotary_emb = DiffLlamaRotaryEmbedding(config=config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
    def get_input_embeddings(self):
        """Return the token-embedding module (`nn.Embedding`)."""
        return self.embed_tokens
+
    def set_input_embeddings(self, value):
        """Replace the token-embedding module with `value`."""
        self.embed_tokens = value
+
    @can_return_tuple
    @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """Run the decoder stack; argument semantics are documented in
        `DIFFLLAMA_INPUTS_DOCSTRING` (injected by the decorator above)."""
        # Fall back to config defaults for unset output/caching flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            # Positions of the new tokens, offset by what is already cached.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
+
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build (or bypass) the 4D causal mask for the configured attention backend.

        Returns `None` when the backend can run fully causal without an explicit
        mask (flash-attention with no padding, or SDPA when
        `_ignore_causal_mask_sdpa` allows it); otherwise returns a mask suitable
        for the backend.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention only needs the 2D padding mask, and only when padding exists.
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        # Static caches are padded to their full capacity, so the mask must cover it.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
+
+ @staticmethod
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ device (`torch.device`):
+ The device to place the 4D attention mask on.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = DiffLlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM
+
+        >>> model = DiffLlamaForCausalLM.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
+        >>> tokenizer = AutoTokenizer.from_pretrained("kajuma/DiffLlama-0.3B-handcut")
+
+ >>> prompt = "What is your favorite condiment?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "What is your favorite condiment?"
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The DiffLlama Model transformer with a sequence classification head on top (linear layer).
+
+ [`DiffLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ DIFFLLAMA_START_DOCSTRING,
+)
+class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = DiffLlamaModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> SequenceClassifierOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+
+ transformer_outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ hidden_states = transformer_outputs.last_hidden_state
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+The DiffLlama Model transformer with a span classification head on top for extractive question-answering tasks like
+SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DIFFLLAMA_START_DOCSTRING,
+)
+class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel):
+ base_model_prefix = "transformer"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = DiffLlamaModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.transformer.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.transformer.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ **kwargs,
+ ) -> QuestionAnsweringModelOutput:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+
+ outputs: BaseModelOutputWithPast = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ sequence_output = outputs.last_hidden_state
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ loss = None
+ if start_positions is not None and end_positions is not None:
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+ return QuestionAnsweringModelOutput(
+ loss=loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ The DiffLlama Model transformer with a token classification head on top (a linear layer on top of the hidden-states
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
+ """,
+ DIFFLLAMA_START_DOCSTRING,
+)
+class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = DiffLlamaModel(config)
+ if getattr(config, "classifier_dropout", None) is not None:
+ classifier_dropout = config.classifier_dropout
+ elif getattr(config, "hidden_dropout", None) is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ ) -> TokenClassifierOutput:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. Tokens with indices set to `-100` are ignored (masked); the loss is
+            only computed for tokens with labels in `[0, ..., config.num_labels - 1]`.
+ """
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ sequence_output = outputs.last_hidden_state
+ sequence_output = self.dropout(sequence_output)
+ logits = self.score(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.config)
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "DiffLlamaPreTrainedModel",
+ "DiffLlamaModel",
+ "DiffLlamaForCausalLM",
+ "DiffLlamaForSequenceClassification",
+ "DiffLlamaForQuestionAnswering",
+ "DiffLlamaForTokenClassification",
+]
diff --git a/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py b/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7bc2d2c5ac1fbf98e7ffe9607d2dd8c4e5ae265
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py
@@ -0,0 +1,480 @@
+# coding=utf-8
+# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on Llama implementations in this library and Microsoft's
+# Differential Transformer implementations.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...cache_utils import Cache, StaticCache
+from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ...utils import logging
+from ..gemma.modeling_gemma import GemmaForCausalLM
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForQuestionAnswering,
+ LlamaForSequenceClassification,
+ LlamaForTokenClassification,
+ LlamaModel,
+ LlamaPreTrainedModel,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from ..mistral.modeling_mistral import MistralMLP
+from .configuration_diffllama import DiffLlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
+_CONFIG_FOR_DOC = "DiffLlamaConfig"
+
+
+class DiffLlamaMLP(MistralMLP):
+ pass
+
+
+def lambda_init_fn(layer_idx):
+ return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
+
+
+class DiffLlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ # under this are not used
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+
+ self.lambda_init = lambda_init_fn(layer_idx)
+ self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
+ self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, target_len, _ = hidden_states.size()
+ q_len = target_len
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class DiffLlamaFlashAttention2(DiffLlamaAttention):
+    """
+    DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (DiffLlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
+ value_states1 = value_states1.repeat(1, 1, 2, 1)
+ value_states2 = value_states2.repeat(1, 1, 2, 1)
+
+ attn_output1 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states1,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output2 = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states2,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ dropout=dropout_rate,
+ sliding_window=getattr(self, "sliding_window", None),
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
+class DiffLlamaSdpaAttention(DiffLlamaAttention):
+    """
+    DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `DiffLlamaAttention` as the weights of the module stay untouched. The only changes are in the forward pass, to
+    adapt to the SDPA API.
+    """
+
+ # Adapted from DiffLlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
+ value_states = value_states.repeat(1, 2, 1, 1)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and causal_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ is_causal = True if causal_mask is None and q_len > 1 else False
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ is_causal=is_causal,
+ )
+
+ attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
+
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
+ query_states.dtype
+ )
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+ attn_output = attn_output1 - lambda_full * attn_output2
+ attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None
+
+
# Maps `config._attn_implementation` to the attention class wired into each decoder layer.
DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}
+
+
class DiffLlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with the self-attention module swapped for the DiffLlama variant."""

    def __init__(self, config: DiffLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the Llama attention with the implementation requested by the config.
        attention_cls = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config, layer_idx=layer_idx)
+
+
class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
    _supports_flex_attn = False
    _supports_attention_backend = False

    def _init_weights(self, module):
        """Initialize weights per module type, following the DiffLlama scheme."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # The padding embedding stays zero so it contributes nothing.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, DiffLlamaRMSNorm):  # noqa: F821
            module.weight.data.fill_(1.0)
        elif isinstance(module, DiffLlamaAttention):
            # The four lambda reparameterization vectors use their own std
            # (initialized in q1, k1, q2, k2 order, matching the original RNG order).
            for name in ("lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"):
                getattr(module, name).data.normal_(0, self.config.lambda_std_dev)
+
+
class DiffLlamaModel(LlamaModel):
    # Unchanged from LlamaModel; the DiffLlama-specific behavior lives in the decoder
    # layers / attention classes defined above.
    pass
+
+
class DiffLlamaForCausalLM(GemmaForCausalLM):
    # NOTE(review): derives from GemmaForCausalLM rather than LlamaForCausalLM in this
    # modular definition — confirm the Gemma causal-LM head behavior is the intended one.
    pass
+
+
class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
    # Inherits LlamaForSequenceClassification unchanged.
    pass
+
+
class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
    # Inherits LlamaForQuestionAnswering unchanged.
    pass
+
+
class DiffLlamaForTokenClassification(LlamaForTokenClassification):
    # Inherits LlamaForTokenClassification unchanged.
    pass
+
+
# Public API of this module (consumed by the lazy import machinery).
__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",  # noqa: F822
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]
diff --git a/docs/transformers/build/lib/transformers/models/dinat/__init__.py b/docs/transformers/build/lib/transformers/models/dinat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b64cdbb3c7eb0467f6112225b8c0d9e1f65f9e99
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinat/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers: expose the real symbols eagerly.
    from .configuration_dinat import *
    from .modeling_dinat import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy that imports the
    # submodules only on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py b/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b432e37c851395e98ff9c0dff859294b3e016f4
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dilated Neighborhood Attention Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
class DinatConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Dinat
    [shi-labs/dinat-mini-in1k-224](https://huggingface.co/shi-labs/dinat-mini-in1k-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 64):
            Dimensionality of patch embedding.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`):
            Number of layers in each level of the encoder.
        num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`):
            Number of attention heads in each layer of the Transformer encoder.
        kernel_size (`int`, *optional*, defaults to 7):
            Neighborhood Attention kernel size.
        dilations (`List[List[int]]`, *optional*, defaults to `[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]`):
            Dilation value of each NA layer in the Transformer encoder.
        mlp_ratio (`float`, *optional*, defaults to 3.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not a learnable bias should be added to the queries, keys and values.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        layer_scale_init_value (`float`, *optional*, defaults to 0.0):
            The initial value for the layer scale. Disabled if <=0.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:

    ```python
    >>> from transformers import DinatConfig, DinatModel

    >>> # Initializing a Dinat shi-labs/dinat-mini-in1k-224 style configuration
    >>> configuration = DinatConfig()

    >>> # Initializing a model (with random weights) from the shi-labs/dinat-mini-in1k-224 style configuration
    >>> model = DinatModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dinat"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        patch_size=4,
        num_channels=3,
        embed_dim=64,
        depths=None,
        num_heads=None,
        kernel_size=7,
        dilations=None,
        mlp_ratio=3.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        layer_scale_init_value=0.0,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # `None` sentinels for the list-valued arguments avoid the mutable-default-argument
        # pitfall (a single shared list stored on every config instance); the effective
        # defaults documented above are unchanged.
        if depths is None:
            depths = [3, 4, 6, 5]
        if num_heads is None:
            num_heads = [2, 4, 8, 16]
        if dilations is None:
            dilations = [[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]

        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.kernel_size = kernel_size
        self.dilations = dilations
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.layer_scale_init_value = layer_scale_init_value
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
diff --git a/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py b/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py
new file mode 100644
index 0000000000000000000000000000000000000000..8837372a84c463af5a92ad23627c9336ec8be24c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py
@@ -0,0 +1,960 @@
+# coding=utf-8
+# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Dilated Neighborhood Attention Transformer model."""
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ OptionalDependencyNotAvailable,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_natten_available,
+ logging,
+ replace_return_docstrings,
+ requires_backends,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_dinat import DinatConfig
+
+
if is_natten_available():
    from natten.functional import natten2dav, natten2dqkrpb
else:
    # Stubs that fail lazily: importing this file works without `natten` installed, but
    # actually running neighborhood attention raises a clear optional-dependency error.

    def natten2dqkrpb(*args, **kwargs):
        raise OptionalDependencyNotAvailable()

    def natten2dav(*args, **kwargs):
        raise OptionalDependencyNotAvailable()
+
+
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "DinatConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "shi-labs/dinat-mini-in1k-224"
# Expected `last_hidden_state` shape for the doc example — presumably (batch, height, width,
# hidden); confirm against the checkpoint.
_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "shi-labs/dinat-mini-in1k-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
+# drop_path and DinatDropPath are from the timm library.
+
+
@dataclass
class DinatEncoderOutput(ModelOutput):
    """
    Output type of the Dinat encoder, with optional hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last stage of the encoder.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor per layer (embedding output + the output of each stage), each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Post-softmax attention weights, one tensor per stage, each of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`, used to compute the
            weighted average in the self-attention heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Same tensors as `hidden_states`, reshaped to `(batch_size, hidden_size, height, width)`
            so the spatial dimensions are preserved.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class DinatModelOutput(ModelOutput):
    """
    Output type of the Dinat model, including an optional pooled representation.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states produced by the last stage of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor per layer (embedding output + the output of each stage), each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Post-softmax attention weights, one tensor per stage, each of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`, used to compute the
            weighted average in the self-attention heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Same tensors as `hidden_states`, reshaped to `(batch_size, hidden_size, height, width)`
            so the spatial dimensions are preserved.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class DinatImageClassifierOutput(ModelOutput):
    """
    Output type of Dinat for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            One tensor per layer (embedding output + the output of each stage), each of shape
            `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Post-softmax attention weights, one tensor per stage, each of shape
            `(batch_size, num_heads, sequence_length, sequence_length)`, used to compute the
            weighted average in the self-attention heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Same tensors as `hidden_states`, reshaped to `(batch_size, hidden_size, height, width)`
            so the spatial dimensions are preserved.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
class DinatEmbeddings(nn.Module):
    """
    Compute the initial patch embeddings, then apply layer normalization and dropout.
    (Dinat uses no absolute position embeddings here; positional information comes from the
    relative positional biases in the attention layers.)
    """

    def __init__(self, config):
        super().__init__()
        self.patch_embeddings = DinatPatchEmbeddings(config)
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
        # Project pixels to patch embeddings, then normalize and regularize.
        patch_embeddings = self.patch_embeddings(pixel_values)
        return self.dropout(self.norm(patch_embeddings))
+
+
class DinatPatchEmbeddings(nn.Module):
    """
    Turn `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` via
    two strided convolutions.
    """

    def __init__(self, config):
        super().__init__()
        num_channels = config.num_channels
        hidden_size = config.embed_dim
        self.num_channels = num_channels

        # TODO: Support arbitrary patch sizes.
        if config.patch_size != 4:
            raise ValueError("Dinat only supports patch size of 4 at the moment.")

        # Two stride-2 convolutions give the overall 4x spatial downsampling (patch size 4).
        self.projection = nn.Sequential(
            nn.Conv2d(num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
        _, num_channels, _, _ = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # Move to channels-last layout, which is what the rest of the model expects.
        return self.projection(pixel_values).permute(0, 2, 3, 1)
+
+
+class DinatDownsampler(nn.Module):
+ """
+ Convolutional Downsampling Layer.
+
+ Args:
+ dim (`int`):
+ Number of input channels.
+ norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+ Normalization layer class.
+ """
+
+ def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
+ input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+ input_feature = self.norm(input_feature)
+ return input_feature
+
+
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.

    Each sample in the batch is independently zeroed with probability `drop_prob`, and the
    survivors are rescaled by `1 / (1 - drop_prob)` so the expectation is preserved. A no-op
    when `drop_prob == 0` or when not training.

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One random value per sample; broadcast across all remaining dims (works for any rank).
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return input.div(keep_prob) * random_tensor
+
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat
class DinatDropPath(nn.Module):
    """Module wrapper around the functional `drop_path` (stochastic depth per sample)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form; only active while training.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
+
+
class NeighborhoodAttention(nn.Module):
    """Neighborhood attention with relative positional biases, computed via the `natten` kernels."""

    def __init__(self, config, dim, num_heads, kernel_size, dilation):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size
        self.dilation = dilation

        # rpb is learnable relative positional biases; same concept is used in Swin.
        # One (2k-1) x (2k-1) bias table per head, covering every relative offset within the kernel.
        self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, height, width, dim) -> (batch, num_heads, height, width, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Apply the scale factor before computing attention weights. It's usually more efficient because
        # attention weights are typically a bigger tensor compared to query.
        # It gives identical results because scalars are commutable in matrix multiplication.
        query_layer = query_layer / math.sqrt(self.attention_head_size)

        # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases.
        attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted sum of values within each neighborhood (natten kernel), then back to
        # (batch, height, width, dim) channels-last layout.
        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
        context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
class NeighborhoodAttentionOutput(nn.Module):
    """Output projection applied after neighborhood attention."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # NOTE(review): `input_tensor` is unused here (kept for interface compatibility);
        # the residual connection is added by the caller.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
+
+
class NeighborhoodAttentionModule(nn.Module):
    """Neighborhood attention plus its output projection, with head-pruning support."""

    def __init__(self, config, dim, num_heads, kernel_size, dilation):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation)
        self.output = NeighborhoodAttentionOutput(config, dim)
        # Indices of heads already removed by `prune_heads`.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink the linear projections accordingly."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
+
+
class DinatIntermediate(nn.Module):
    """First MLP layer: expand `dim` by `config.mlp_ratio` and apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        # Accept either an activation name (looked up in ACT2FN) or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
class DinatOutput(nn.Module):
    """Second MLP layer: project back from the expanded dimension to `dim`, then dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
+
+
+class DinatLayer(nn.Module):
    def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0):
        """Build one Dinat block: neighborhood attention and MLP sublayers with pre-LayerNorm."""
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.dilation = dilation
        # Effective receptive field of dilated neighborhood attention; inputs smaller
        # than this are padded in `maybe_pad`.
        self.window_size = self.kernel_size * self.dilation
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(
            config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation
        )
        self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DinatIntermediate(config, dim)
        self.output = DinatOutput(config, dim)
        # Optional LayerScale: two learnable per-channel scales (index 0 scales the attention
        # output; index 1 presumably the MLP path — see `forward`). Disabled when
        # layer_scale_init_value <= 0.
        self.layer_scale_parameters = (
            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )
+
+ def maybe_pad(self, hidden_states, height, width):
+ window_size = self.window_size
+ pad_values = (0, 0, 0, 0, 0, 0)
+ if height < window_size or width < window_size:
+ pad_l = pad_t = 0
+ pad_r = max(0, window_size - width)
+ pad_b = max(0, window_size - height)
+ pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
+ hidden_states = nn.functional.pad(hidden_states, pad_values)
+ return hidden_states, pad_values
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, height, width, channels = hidden_states.size()
+ shortcut = hidden_states
+
+ hidden_states = self.layernorm_before(hidden_states)
+ # pad hidden_states if they are smaller than kernel size x dilation
+ hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+ _, height_pad, width_pad, _ = hidden_states.shape
+
+ attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)
+
+ attention_output = attention_outputs[0]
+
+ was_padded = pad_values[3] > 0 or pad_values[5] > 0
+ if was_padded:
+ attention_output = attention_output[:, :height, :width, :].contiguous()
+
+ if self.layer_scale_parameters is not None:
+ attention_output = self.layer_scale_parameters[0] * attention_output
+
+ hidden_states = shortcut + self.drop_path(attention_output)
+
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.output(self.intermediate(layer_output))
+
+ if self.layer_scale_parameters is not None:
+ layer_output = self.layer_scale_parameters[1] * layer_output
+
+ layer_output = hidden_states + self.drop_path(layer_output)
+
+ layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+ return layer_outputs
+
+
class DinatStage(nn.Module):
    """A stack of `depth` DinatLayers at one resolution, optionally followed by a patch-merging downsampler."""

    def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        blocks = []
        for layer_idx in range(depth):
            blocks.append(
                DinatLayer(
                    config=config,
                    dim=dim,
                    num_heads=num_heads,
                    dilation=dilations[layer_idx],
                    drop_path_rate=drop_path_rate[layer_idx],
                )
            )
        self.layers = nn.ModuleList(blocks)

        # Patch merging layer between stages; the final stage passes `downsample=None`.
        self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm) if downsample is not None else None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        _, height, width, _ = hidden_states.size()
        for layer_module in self.layers:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        # Keep the pre-downsampling activations: callers (e.g. backbones) need them.
        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)
        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs
+
+
class DinatEncoder(nn.Module):
    """The full DiNAT encoder: a pyramid of DinatStages whose channel width doubles at each level."""

    def __init__(self, config):
        super().__init__()
        self.num_levels = len(config.depths)
        self.config = config
        # Per-layer stochastic-depth rates, increasing linearly across the whole network.
        drop_path_rates = [
            rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
        ]
        stages = []
        for level_idx in range(self.num_levels):
            start = sum(config.depths[:level_idx])
            stop = sum(config.depths[: level_idx + 1])
            stages.append(
                DinatStage(
                    config=config,
                    dim=int(config.embed_dim * 2**level_idx),
                    depth=config.depths[level_idx],
                    num_heads=config.num_heads[level_idx],
                    dilations=config.dilations[level_idx],
                    drop_path_rate=drop_path_rates[start:stop],
                    downsample=DinatDownsampler if level_idx < self.num_levels - 1 else None,
                )
            )
        self.levels = nn.ModuleList(stages)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DinatEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        def _record(state):
            # Stores both the NHWC tensor and its NCHW ("reshaped") permutation.
            nonlocal all_hidden_states, all_reshaped_hidden_states
            all_hidden_states += (state,)
            all_reshaped_hidden_states += (state.permute(0, 3, 1, 2),)

        if output_hidden_states:
            _record(hidden_states)

        for layer_module in self.levels:
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]

            if output_hidden_states:
                _record(
                    hidden_states_before_downsampling
                    if output_hidden_states_before_downsampling
                    else hidden_states
                )

            if output_attentions:
                all_self_attentions += layer_outputs[2:]

        # NOTE: the tuple return intentionally omits reshaped_hidden_states, matching the dict path's extras.
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return DinatEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
+
+
class DinatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DinatConfig
    base_model_prefix = "dinat"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize one submodule's weights in place (invoked by `post_init`)."""
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
+
+
# Docstring fragment injected onto the model classes below by `add_start_docstrings`.
DINAT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DinatConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring fragment injected onto `forward` by `add_start_docstrings_to_model_forward`.
DINAT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Dinat Model transformer outputting raw hidden-states without any specific head on top.",
    DINAT_START_DOCSTRING,
)
class DinatModel(DinatPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        # Channel width of the final stage (doubles at every downsampling step).
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module that maps pixels to the first hidden states."""
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Fix: DinatEncoder stores its transformer blocks as `levels[i].layers[j]`; there is no
        # `self.encoder.layer` attribute, so the previous implementation always raised AttributeError.
        # Flatten the blocks in depth order so `layer_num` indexes them globally.
        blocks = [block for level in self.encoder.levels for block in level.layers]
        for layer, heads in heads_to_prune.items():
            blocks[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=DinatModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DinatModelOutput]:
        """Embed pixels, run the encoder, layernorm the last hidden state and (optionally) pool it."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            # (batch, H, W, C) -> (batch, C, H*W) so AdaptiveAvgPool1d averages over spatial positions.
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]

            return output

        return DinatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
+
+
@add_start_docstrings(
    """
    Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINAT_START_DOCSTRING,
)
class DinatForImageClassification(DinatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.dinat = DinatModel(config)

        # Classifier head: identity when the config declares no labels.
        if config.num_labels > 0:
            self.classifier = nn.Linear(self.dinat.num_features, config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    def _compute_loss(self, logits, labels):
        """Infer (and cache on the config) the problem type, then compute the matching loss."""
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                return loss_fct(logits.squeeze(), labels.squeeze())
            return loss_fct(logits, labels)
        if self.config.problem_type == "single_label_classification":
            return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if self.config.problem_type == "multi_label_classification":
            return BCEWithLogitsLoss()(logits, labels)
        return None

    @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=DinatImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DinatImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.dinat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled representation in both the tuple and the dataclass output.
        logits = self.classifier(outputs[1])
        loss = self._compute_loss(logits, labels) if labels is not None else None

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DinatImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
+
+
@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    DINAT_START_DOCSTRING,
)
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        requires_backends(self, ["natten"])

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
        self.hidden_states_norms = nn.ModuleDict(
            {stage: nn.LayerNorm(num_channels) for stage, num_channels in zip(self._out_features, self.channels)}
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the patch-embedding module that maps pixels to the first hidden states."""
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # Always request every stage's (pre-downsampling) hidden states: the backbone selects from them.
        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage not in self.out_features:
                continue
            batch_size, num_channels, height, width = hidden_state.shape
            # LayerNorm normalizes the channel axis, so flatten NCHW -> (B, H*W, C) first.
            flattened = hidden_state.permute(0, 2, 3, 1).contiguous().view(batch_size, height * width, num_channels)
            normed = self.hidden_states_norms[stage](flattened)
            restored = normed.view(batch_size, height, width, num_channels).permute(0, 3, 1, 2).contiguous()
            feature_maps += (restored,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
+
+
# Public symbols re-exported by the package's lazy-module machinery.
__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2/__init__.py b/docs/transformers/build/lib/transformers/models/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc316957eac509573bf44785209d0729ea13bb6
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # Static type checkers see the real submodules directly.
    from .configuration_dinov2 import *
    from .modeling_dinov2 import *
    from .modeling_flax_dinov2 import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy that imports each
    # submodule only on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b29273a509d3e25d7a11eb77de8808e0893f96
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py
@@ -0,0 +1,179 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DINOv2 model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
class Dinov2Config(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Dinov2Model`]. It is used to instantiate an
    Dinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Dinov2
    [facebook/dinov2-base](https://huggingface.co/facebook/dinov2-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Ratio of the hidden size of the MLPs relative to the `hidden_size`.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        layerscale_value (`float`, *optional*, defaults to 1.0):
            Initial value to use for layer scale.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
            Whether to use the SwiGLU feedforward neural network.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        apply_layernorm (`bool`, *optional*, defaults to `True`):
            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
            seq_len, hidden_size)`.
        use_mask_token (`bool`, *optional*, defaults to `True`):
            Whether to use mask_token in embeddings.

    Example:

    ```python
    >>> from transformers import Dinov2Config, Dinov2Model

    >>> # Initializing a Dinov2 base style configuration
    >>> configuration = Dinov2Config()

    >>> # Initializing a model (with random weights) from the base style configuration
    >>> model = Dinov2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dinov2"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mlp_ratio=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        image_size=224,
        patch_size=14,
        num_channels=3,
        qkv_bias=True,
        layerscale_value=1.0,
        drop_path_rate=0.0,
        use_swiglu_ffn=False,
        out_features=None,
        out_indices=None,
        apply_layernorm=True,
        reshape_hidden_states=True,
        use_mask_token=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.layerscale_value = layerscale_value
        self.drop_path_rate = drop_path_rate
        self.use_swiglu_ffn = use_swiglu_ffn
        # Backbone API: one pseudo-stage per transformer layer, preceded by the patch-embedding "stem".
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
        self.apply_layernorm = apply_layernorm
        self.reshape_hidden_states = reshape_hidden_states
        self.use_mask_token = use_mask_token
+
+
class Dinov2OnnxConfig(OnnxConfig):
    """ONNX export configuration for DINOv2 (vision-only: a single `pixel_values` input)."""

    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Dynamic axes of the image tensor: batch, channels and both spatial dims may vary.
        dynamic_axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict(pixel_values=dynamic_axes)

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance when validating exported outputs against PyTorch.
        return 1e-4
+
+
# Public symbols re-exported by the package's lazy-module machinery.
__all__ = ["Dinov2Config", "Dinov2OnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py b/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d716191b2fcbd4775bd2349ef98a7ad0d781a90c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py
@@ -0,0 +1,285 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DINOv2 checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/dinov2/tree/main
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from torchvision import transforms
+
+from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
+from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dinov2_config(model_name, image_classifier=False):
    """
    Build a `Dinov2Config` matching the architecture encoded in `model_name`.

    Args:
        model_name: torch-hub checkpoint name; must contain one of "vits", "vitb",
            "vitl" or "vitg" to select the architecture size.
        image_classifier: when True, also attach the ImageNet-1k id2label mapping so
            the config can be used with `Dinov2ForImageClassification`.

    Raises:
        ValueError: if `model_name` does not name a supported architecture.
    """
    config = Dinov2Config(image_size=518, patch_size=14)

    # size of the architecture
    if "vits" in model_name:
        config.hidden_size = 384
        config.num_attention_heads = 6
    elif "vitb" in model_name:
        # base is the default Dinov2Config architecture
        pass
    elif "vitl" in model_name:
        config.hidden_size = 1024
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif "vitg" in model_name:
        # the giant variant uses a SwiGLU feed-forward network
        config.use_swiglu_ffn = True
        config.hidden_size = 1536
        config.num_hidden_layers = 40
        config.num_attention_heads = 24
    else:
        raise ValueError("Model not supported")

    if image_classifier:
        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        config.num_labels = 1000
        # Use a context manager so the downloaded label file is closed; the previous
        # `json.load(open(...))` leaked the file handle.
        with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f:
            id2label = json.load(f)
        config.id2label = {int(k): v for k, v in id2label.items()}

    return config
+
+
def create_rename_keys(config):
    """Return (original_key, hf_key) pairs mapping DINOv2 checkpoint names to the HF layout."""
    pairs = []

    # Embedding-related parameters.
    pairs.extend(
        [
            ("cls_token", "embeddings.cls_token"),
            ("mask_token", "embeddings.mask_token"),
            ("pos_embed", "embeddings.position_embeddings"),
            ("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias"),
        ]
    )

    for layer_idx in range(config.num_hidden_layers):
        src = f"blocks.{layer_idx}"
        dst = f"encoder.layer.{layer_idx}"
        # Pre- and post-attention layernorms.
        for norm_name in ("norm1", "norm2"):
            for param in ("weight", "bias"):
                pairs.append((f"{src}.{norm_name}.{param}", f"{dst}.{norm_name}.{param}"))
        # Feed-forward network (SwiGLU for the giant variant, plain MLP otherwise).
        mlp_names = ("w12", "w3") if config.use_swiglu_ffn else ("fc1", "fc2")
        for mlp_name in mlp_names:
            for param in ("weight", "bias"):
                pairs.append((f"{src}.mlp.{mlp_name}.{param}", f"{dst}.mlp.{mlp_name}.{param}"))
        # Layer scale parameters.
        pairs.append((f"{src}.ls1.gamma", f"{dst}.layer_scale1.lambda1"))
        pairs.append((f"{src}.ls2.gamma", f"{dst}.layer_scale2.lambda1"))
        # Attention output projection (qkv is handled separately in read_in_q_k_v).
        for param in ("weight", "bias"):
            pairs.append((f"{src}.attn.proj.{param}", f"{dst}.attention.output.dense.{param}"))

    # Final layernorm.
    pairs.append(("norm.weight", "layernorm.weight"))
    pairs.append(("norm.bias", "layernorm.bias"))

    return pairs
+
+
def rename_key(dct, old, new):
    """Move the value stored under `old` to the key `new` (mutates `dct` in place)."""
    dct[new] = dct.pop(old)
+
+
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
    """Split each layer's fused qkv projection into separate query/key/value entries (in place)."""
    hidden = config.hidden_size
    for layer_idx in range(config.num_hidden_layers):
        # The original checkpoint stores q, k and v stacked along the first axis
        # of a single matrix (and bias vector).
        fused_weight = state_dict.pop(f"blocks.{layer_idx}.attn.qkv.weight")
        fused_bias = state_dict.pop(f"blocks.{layer_idx}.attn.qkv.bias")
        prefix = f"encoder.layer.{layer_idx}.attention.attention"
        for part_idx, part in enumerate(("query", "key", "value")):
            start, stop = part_idx * hidden, (part_idx + 1) * hidden
            state_dict[f"{prefix}.{part}.weight"] = fused_weight[start:stop, :]
            state_dict[f"{prefix}.{part}.bias"] = fused_bias[start:stop]
+
+
# We will verify our results on an image of cute cats
def prepare_img():
    """Download the standard COCO validation image (two cats) used to sanity-check conversions."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # stream=True lets PIL read directly from the HTTP response body.
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    return image
+
+
@torch.no_grad()
def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our DINOv2 structure.

    Loads the original checkpoint from torch hub, remaps its parameter names to the
    HuggingFace layout, verifies the outputs match on a sample image, and optionally
    saves and/or pushes the converted model and image processor.
    """

    # define default Dinov2 configuration
    image_classifier = "1layer" in model_name
    config = get_dinov2_config(model_name, image_classifier=image_classifier)

    # load original model from torch hub
    original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
    original_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = original_model.state_dict()
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config)

    # SwiGLU weights keep their original w12/w3 names in create_rename_keys;
    # map them to the HF weights_in/weights_out names here.
    for key, val in state_dict.copy().items():
        val = state_dict.pop(key)
        if "w12" in key:
            key = key.replace("w12", "weights_in")
        if "w3" in key:
            key = key.replace("w3", "weights_out")
        state_dict[key] = val

    # load HuggingFace model
    if image_classifier:
        model = Dinov2ForImageClassification(config).eval()
        model.dinov2.load_state_dict(state_dict)
        # The linear classification heads are distributed separately from the backbone.
        model_name_to_classifier_dict_url = {
            "dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
            "dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
            "dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
            "dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
        }
        url = model_name_to_classifier_dict_url[model_name]
        classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
        model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
    else:
        model = Dinov2Model(config).eval()
        model.load_state_dict(state_dict)

    # load image
    image = prepare_img()

    # preprocess image
    transformations = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_DEFAULT_MEAN,  # these are RGB mean+std values
                std=IMAGENET_DEFAULT_STD,  # across a large photo dataset.
            ),
        ]
    )

    original_pixel_values = transformations(image).unsqueeze(0)  # insert batch dimension

    # The HF image processor must reproduce the torchvision pipeline above exactly;
    # the allclose assert below verifies that.
    processor = BitImageProcessor(
        size={"shortest_edge": 256},
        resample=PILImageResampling.BICUBIC,
        image_mean=IMAGENET_DEFAULT_MEAN,
        image_std=IMAGENET_DEFAULT_STD,
    )
    pixel_values = processor(image, return_tensors="pt").pixel_values

    assert torch.allclose(original_pixel_values, pixel_values)

    with torch.no_grad():
        outputs = model(pixel_values, output_hidden_states=True)
        original_outputs = original_model(pixel_values)

    # assert values
    if image_classifier:
        print("Predicted class:")
        class_idx = outputs.logits.argmax(-1).item()
        print(model.config.id2label[class_idx])
    else:
        # Compare the [CLS] token representation against the original model's output.
        assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
        assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
        print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving image processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Map torch-hub checkpoint names to their HF hub repository names.
        model_name_to_hf_name = {
            "dinov2_vits14": "dinov2-small",
            "dinov2_vitb14": "dinov2-base",
            "dinov2_vitl14": "dinov2-large",
            "dinov2_vitg14": "dinov2-giant",
            "dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
            "dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
            "dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
            "dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
        }

        name = model_name_to_hf_name[model_name]
        model.push_to_hub(f"facebook/{name}")
        processor.push_to_hub(f"facebook/{name}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="dinov2_vitb14",
+ type=str,
+ choices=[
+ "dinov2_vits14",
+ "dinov2_vitb14",
+ "dinov2_vitl14",
+ "dinov2_vitg14",
+ "dinov2_vits14_1layer",
+ "dinov2_vitb14_1layer",
+ "dinov2_vitl14_1layer",
+ "dinov2_vitg14_1layer",
+ ],
+ help="Name of the model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ )
+ parser.add_argument(
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+ )
+
+ args = parser.parse_args()
+ convert_dinov2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed5a4ec6cb7b180e23c84a8beb1b3a2dc0ed392
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py
@@ -0,0 +1,904 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DINOv2 model."""
+
+import collections.abc
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BackboneOutput,
+ BaseModelOutput,
+ BaseModelOutputWithPooling,
+ ImageClassifierOutput,
+)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ torch_int,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_dinov2 import Dinov2Config
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "Dinov2Config"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
class Dinov2Embeddings(nn.Module):
    """
    Construct the CLS token, mask token, position and patch embeddings.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        # Learnable classification token prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        if config.use_mask_token:
            # Learnable embedding swapped in for masked patches (used in pre-training).
            self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.patch_embeddings = Dinov2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # One position per patch plus one for the [CLS] token.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.use_mask_token = config.use_mask_token
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        # The [CLS] position embedding is kept as-is; only patch positions get interpolated.
        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Pre-trained positions form a square grid; recover its side length.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        target_dtype = patch_pos_embed.dtype
        # Interpolate in float32 for numerical stability, then cast back to the original dtype.
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(torch.float32),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        ).to(dtype=target_dtype)

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed `pixel_values`; where `bool_masked_pos` is True the patch embedding is replaced by the mask token."""
        batch_size, _, height, width = pixel_values.shape
        # Cast pixels to the projection weight dtype (e.g. for fp16 checkpoints).
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        if bool_masked_pos is not None and self.use_mask_token:
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings
+
+
class Dinov2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer. A strided convolution computes one embedding per non-overlapping patch.
    """

    def __init__(self, config):
        super().__init__()

        def as_pair(value):
            # Accept either a single int or an already-iterable (height, width) pair.
            return value if isinstance(value, collections.abc.Iterable) else (value, value)

        image_size = as_pair(config.image_size)
        patch_size = as_pair(config.patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.projection = nn.Conv2d(
            config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        channels = pixel_values.shape[1]
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels}."
            )
        patches = self.projection(pixel_values)
        # (batch, hidden, h', w') -> (batch, h'*w', hidden)
        return patches.flatten(2).transpose(1, 2)
+
+
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Plain (non-fused) scaled dot-product attention.

    Returns a tuple of the context tensor, transposed to (batch, seq, heads, head_size),
    and the attention probabilities.
    """
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # Dropout on the attention distribution drops whole tokens to attend to,
    # as in the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Optional multiplicative mask (used for head masking).
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous()

    return context, probs
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
class Dinov2SelfAttention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections and a pluggable attention kernel."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        # hidden_size must split evenly across heads unless an explicit embedding_size is configured.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        # 1/sqrt(head_size) scaling applied to the attention scores.
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Returns `(context,)` or `(context, attention_probs)` when `output_attentions` is True."""
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Select the attention kernel configured on the model; default is the eager path.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                # SDPA cannot return attention weights, so fall back to the eager path.
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # Merge heads back: (batch, seq, num_heads, head_size) -> (batch, seq, all_head_size).
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
class Dinov2SelfOutput(nn.Module):
    """
    The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # NOTE: input_tensor is accepted for interface compatibility but unused;
        # the residual add happens in Dinov2Layer.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
+
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
class Dinov2Attention(nn.Module):
    """Self-attention followed by its output projection, with support for head pruning."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.attention = Dinov2SelfAttention(config)
        self.output = Dinov2SelfOutput(config)
        # Indices of heads already pruned, so repeated pruning stays consistent.
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the q/k/v and output projections (in place)."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        # Only the context is projected here; the residual add happens in Dinov2Layer.
        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
+
+
class Dinov2LayerScale(nn.Module):
    """Per-channel learnable scaling (LayerScale) applied to a residual branch."""

    def __init__(self, config) -> None:
        super().__init__()
        # Initialized to a constant from the config; _init_weights may refill it.
        self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.lambda1 * hidden_state
+
+
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.

    Each sample in the batch is independently zeroed with probability `drop_prob`; surviving
    samples are rescaled by 1/(1-drop_prob) so the expected value is unchanged. (Despite
    sometimes being called "DropConnect", that name refers to a different kind of dropout;
    see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.)
    """
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device).add_(keep_prob).floor_()
    return input.div(keep_prob) * mask
+
+
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class Dinov2DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth per sample in residual paths)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Active only while self.training is True.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
+
+
class Dinov2MLP(nn.Module):
    """Standard two-layer transformer feed-forward block (fc1 -> activation -> fc2)."""

    def __init__(self, config) -> None:
        super().__init__()
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(config.hidden_size, hidden_features, bias=True)
        # `hidden_act` may be a key into ACT2FN or a callable supplied directly.
        if isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act
        self.fc2 = nn.Linear(hidden_features, config.hidden_size, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
+
+
class Dinov2SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block (used by the giant DINOv2 variant instead of the plain MLP)."""

    def __init__(self, config) -> None:
        super().__init__()
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        # The 2/3 factor keeps the parameter count comparable to a plain MLP;
        # round up to a multiple of 8 for hardware-friendly sizes.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(config.hidden_size, 2 * hidden_features, bias=True)
        self.weights_out = nn.Linear(hidden_features, config.hidden_size, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # A single projection produces both the gate and the value half.
        gate, value = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * value)
+
+
class Dinov2Layer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Dinov2Attention(config)
        self.layer_scale1 = Dinov2LayerScale(config)
        # Stochastic depth; identity when drop_path_rate is 0.
        self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The giant variant uses a SwiGLU feed-forward network, the others a plain MLP.
        if config.use_swiglu_ffn:
            self.mlp = Dinov2SwiGLUFFN(config)
        else:
            self.mlp = Dinov2MLP(config)
        self.layer_scale2 = Dinov2LayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        """Pre-norm transformer block: norm -> attention -> scale -> residual, then norm -> mlp -> scale -> residual."""
        self_attention_outputs = self.attention(
            self.norm1(hidden_states),  # in Dinov2, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        attention_output = self.layer_scale1(attention_output)
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Dinov2, layernorm is also applied after self-attention
        layer_output = self.norm2(hidden_states)
        layer_output = self.mlp(layer_output)
        layer_output = self.layer_scale2(layer_output)

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
+
+
# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
class Dinov2Encoder(nn.Module):
    """Stack of `num_hidden_layers` Dinov2 transformer blocks."""

    def __init__(self, config: Dinov2Config) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally (e.g. by gradient_checkpointing_enable on the model).
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run all layers; optionally collect per-layer hidden states and attention maps."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the *input* of each layer, so the embeddings are included too.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: recompute activations during backward.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output: drop the entries that were not requested (None).
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
class Dinov2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices by an automatic device map.
    _no_split_modules = ["Dinov2SwiGLUFFN"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2Embeddings):
            # Same fp32 upcast trick for the learned position and CLS embeddings.
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

            if self.config.use_mask_token:
                module.mask_token.data.zero_()
        elif isinstance(module, Dinov2LayerScale):
            # LayerScale starts at the configured constant (typically small, near-identity).
            module.lambda1.data.fill_(self.config.layerscale_value)
+
+
+# Docstring fragments consumed by the `@add_start_docstrings*` decorators on the
+# model classes below. `DINOV2_BASE_INPUTS_DOCSTRING` additionally documents
+# `bool_masked_pos`, which only the base `Dinov2Model` accepts.
+DINOV2_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DINOV2_BASE_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`BitImageProcessor.preprocess`] for details.
+
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
+            pre-training.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+DINOV2_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`BitImageProcessor.preprocess`] for details.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
+    DINOV2_START_DOCSTRING,
+)
+class Dinov2Model(Dinov2PreTrainedModel):
+    """Bare DINOv2 encoder: patch/CLS embeddings -> transformer encoder -> final LayerNorm."""
+
+    def __init__(self, config: Dinov2Config):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = Dinov2Embeddings(config)
+        self.encoder = Dinov2Encoder(config)
+
+        # Final LayerNorm applied to the encoder output (pre-norm architecture).
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
+        # The patch-embedding projection plays the role of the input embedding layer.
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPooling,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        # Fall back to the config defaults when flags are not passed explicitly.
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        # "Pooled" output is simply the (layernormed) [CLS] token at position 0.
+        pooled_output = sequence_output[:, 0, :]
+
+        if not return_dict:
+            # Tuple layout: (sequence_output, pooled_output, [hidden_states], [attentions]).
+            head_outputs = (sequence_output, pooled_output)
+            return head_outputs + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+    of the [CLS] token) e.g. for ImageNet.
+    """,
+    DINOV2_START_DOCSTRING,
+)
+class Dinov2ForImageClassification(Dinov2PreTrainedModel):
+    def __init__(self, config: Dinov2Config) -> None:
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.dinov2 = Dinov2Model(config)
+
+        # Classifier head
+        # Input is [CLS] concatenated with the mean of the patch tokens, hence hidden_size * 2.
+        self.classifier = (
+            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=ImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.dinov2(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size
+
+        # Token 0 is [CLS]; the rest are patch tokens.
+        cls_token = sequence_output[:, 0]
+        patch_tokens = sequence_output[:, 1:]
+
+        # Classifier input: [CLS] concatenated with the mean patch token (2 * hidden_size).
+        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
+
+        logits = self.classifier(linear_input)
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Infer the problem type once (cached on the config) from num_labels and label dtype.
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            # outputs[2:] skips (last_hidden_state, pooler_output) from the base model tuple.
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
+    of the [CLS] token) e.g. for ImageNet.
+    """,
+    DINOV2_START_DOCSTRING,
+)
+class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
+    def __init__(self, config):
+        super().__init__(config)
+        super()._init_backbone(config)
+
+        # One feature size per hidden state: the embeddings plus each encoder layer.
+        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+        self.embeddings = Dinov2Embeddings(config)
+        self.encoder = Dinov2Encoder(config)
+
+        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
+        return self.embeddings.patch_embeddings
+
+    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+
+        >>> outputs = model(**inputs)
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 16, 16]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        embedding_output = self.embeddings(pixel_values)
+
+        # Hidden states of every layer are always needed to build the feature maps,
+        # regardless of the user's `output_hidden_states` flag.
+        outputs = self.encoder(
+            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
+        )
+
+        hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                if self.config.apply_layernorm:
+                    hidden_state = self.layernorm(hidden_state)
+                if self.config.reshape_hidden_states:
+                    # Drop the [CLS] token before reshaping patch tokens into a 2D map.
+                    hidden_state = hidden_state[:, 1:]
+                    # this was actually a bug in the original implementation that we copied here,
+                    # cause normally the order is height, width
+                    batch_size, _, height, width = pixel_values.shape
+                    patch_size = self.config.patch_size
+                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
+                    # (B, H', W', C) -> (B, C, H', W') for downstream conv-style consumers.
+                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            # outputs[1] is the hidden-states tuple, outputs[2:] the (optional) attentions;
+            # only forward the hidden states if the caller asked for them.
+            if output_hidden_states:
+                output = (feature_maps,) + outputs[1:]
+            else:
+                output = (feature_maps,) + outputs[2:]
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions if output_attentions else None,
+        )
+
+
+# Public API of this module.
+__all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2766850e921ea805d525c330858babab23614204
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py
@@ -0,0 +1,801 @@
+# coding=utf-8
+# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Flax DINOv2 model."""
+
+import collections.abc
+import math
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from .configuration_dinov2 import Dinov2Config
+
+
+# Docstring fragments for the Flax model classes; consumed by the
+# `add_start_docstrings*` / `overwrite_call_docstring` helpers imported above.
+DINOV2_START_DOCSTRING = r"""
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
+
+    This model is also a
+    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
+    behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+DINOV2_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Dinov2ImageProcessor.__call__`]
+            for details.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxDinov2PatchEmbeddings(nn.Module):
+    """Turns an image into a sequence of patch embeddings via a strided convolution."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        image_size = self.config.image_size
+        patch_size = self.config.patch_size
+        # Normalize int config values to (height, width) pairs.
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+
+        self.num_patches = num_patches
+        self.num_channels = self.config.num_channels
+        # Non-overlapping patches: kernel size == stride == patch size.
+        self.projection = nn.Conv(
+            self.config.hidden_size,
+            kernel_size=patch_size,
+            strides=patch_size,
+            padding="VALID",
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+        )
+
+    # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings.__call__
+    def __call__(self, pixel_values):
+        # Flax convention is channels-last: pixel_values is (batch, height, width, channels).
+        num_channels = pixel_values.shape[-1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+            )
+        embeddings = self.projection(pixel_values)
+        batch_size, _, _, channels = embeddings.shape
+        # Flatten the spatial grid into one sequence dimension: (batch, num_patches, hidden).
+        return jnp.reshape(embeddings, (batch_size, -1, channels))
+
+
+class FlaxDinov2Embeddings(nn.Module):
+    """Construct the CLS token, position and patch embeddings."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.cls_token = self.param(
+            "cls_token",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, 1, self.config.hidden_size),
+        )
+        if self.config.use_mask_token:
+            # Learned token substituted for masked patches during pre-training.
+            self.mask_token = self.param(
+                "mask_token",
+                jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+                (1, self.config.hidden_size),
+            )
+        self.patch_embeddings = FlaxDinov2PatchEmbeddings(self.config, dtype=self.dtype)
+        num_patches = self.patch_embeddings.num_patches
+        # +1 position for the CLS token prepended to the patch sequence.
+        self.position_embeddings = self.param(
+            "position_embeddings",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, num_patches + 1, self.config.hidden_size),
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def interpolate_pos_encoding(self, config, hidden_states, height, width, position_embeddings):
+        """Resize the learned position embeddings so images of other resolutions can be processed."""
+        num_patches = hidden_states.shape[1] - 1
+        num_positions = position_embeddings.shape[1] - 1
+        # No interpolation needed when the patch grid matches the trained one.
+        if num_patches == num_positions and height == width:
+            return position_embeddings
+        class_pos_embed = position_embeddings[:, 0]
+        patch_pos_embed = position_embeddings[:, 1:]
+        dim = hidden_states.shape[-1]
+
+        h = height // config.patch_size
+        w = width // config.patch_size
+        # Small offset mirroring the original DINOv2 interpolation code, presumably to
+        # avoid floating-point rounding in the scale ratios — NOTE(review): confirm.
+        height, width = h + 0.1, w + 0.1
+
+        # Recover the trained square patch grid: (1, sqrt(N), sqrt(N), dim) -> NCHW.
+        patch_pos_embed = patch_pos_embed.reshape(
+            (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        )
+        patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 3, 1, 2))
+        target_dtype = patch_pos_embed.dtype
+        new_height_ratio = jnp.float32(height / math.sqrt(num_positions))
+        new_width_ratio = jnp.float32(width / math.sqrt(num_positions))
+
+        scale = jnp.array([new_height_ratio, new_width_ratio], dtype=jnp.float32)
+        translation = jnp.array([0.0, 0.0], dtype=jnp.float32)
+
+        # Bicubic resize in fp32 (cast back afterwards) over the two spatial dims.
+        patch_pos_embed = jax.image.scale_and_translate(
+            patch_pos_embed.astype(jnp.float32),
+            shape=(patch_pos_embed.shape[0], patch_pos_embed.shape[1], h, w),
+            spatial_dims=(2, 3),
+            scale=scale,
+            translation=translation,
+            method="bicubic",
+            antialias=False,
+        )
+        patch_pos_embed = patch_pos_embed.astype(target_dtype)
+        # Back to sequence layout and broadcast across the batch.
+        patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim))
+        patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1))
+        class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1))
+
+        return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1)
+
+    def __call__(self, pixel_values, deterministic=True):
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embeddings.projection.dtype
+        # Channels-last input: shape is (batch, height, width, channels).
+        height, width = pixel_values.shape[1], pixel_values.shape[2]
+
+        embeddings = self.patch_embeddings(pixel_values.astype(target_dtype))
+
+        # Prepend the CLS token, then add (possibly resized) position embeddings.
+        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
+        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
+
+        embeddings = embeddings + self.interpolate_pos_encoding(
+            self.config, embeddings, height, width, self.position_embeddings
+        )
+
+        embeddings = self.dropout(embeddings, deterministic=deterministic)
+        return embeddings
+
+
+# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->Dinov2
+class FlaxDinov2SelfAttention(nn.Module):
+ config: Dinov2Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
+ raise ValueError(
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
+ " {self.config.num_attention_heads}"
+ )
+
+ self.query = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.variance_scaling(
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+ ),
+ use_bias=self.config.qkv_bias,
+ )
+ self.key = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.variance_scaling(
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+ ),
+ use_bias=self.config.qkv_bias,
+ )
+ self.value = nn.Dense(
+ self.config.hidden_size,
+ dtype=self.dtype,
+ kernel_init=jax.nn.initializers.variance_scaling(
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+ ),
+ use_bias=self.config.qkv_bias,
+ )
+
+ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+ query_states = self.query(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+ value_states = self.value(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+ key_states = self.key(hidden_states).reshape(
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+ )
+
+ dropout_rng = None
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+ dropout_rng = self.make_rng("dropout")
+
+ attn_weights = dot_product_attention_weights(
+ query_states,
+ key_states,
+ dropout_rng=dropout_rng,
+ dropout_rate=self.config.attention_probs_dropout_prob,
+ broadcast_dropout=True,
+ deterministic=deterministic,
+ dtype=self.dtype,
+ precision=None,
+ )
+
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
+
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
+ return outputs
+
+
+# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->Dinov2
+class FlaxDinov2SelfOutput(nn.Module):
+    """Output projection after self-attention: Dense + dropout (no residual here)."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.dense = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+            dtype=self.dtype,
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
+        # `input_tensor` is accepted for interface parity but deliberately unused:
+        # the residual connection is applied in FlaxDinov2Layer, after LayerScale.
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
+        return hidden_states
+
+
+# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->Dinov2
+class FlaxDinov2Attention(nn.Module):
+    """Self-attention block: FlaxDinov2SelfAttention followed by the output projection."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.attention = FlaxDinov2SelfAttention(self.config, dtype=self.dtype)
+        self.output = FlaxDinov2SelfOutput(self.config, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
+        attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
+        attn_output = attn_outputs[0]
+        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
+
+        outputs = (hidden_states,)
+
+        # Optionally append the attention probabilities for the caller.
+        if output_attentions:
+            outputs += (attn_outputs[1],)
+
+        return outputs
+
+
+def ones_with_scale(key, shape, scale, dtype=jnp.float32):
+    """Initializer-style helper returning an array of ones scaled by `scale`; `key` is unused."""
+    return jnp.ones(shape, dtype) * scale
+
+
+class FlaxDinov2LayerScale(nn.Module):
+ config: Dinov2Config
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
+
+ def setup(self):
+ self.lambda1 = self.config.layerscale_value * self.param(
+ "lambda1",
+ jax.nn.initializers.ones,
+ (self.config.hidden_size,),
+ )
+ self.lambda1 = self.lambda1 * self.config.layerscale_value
+
+ def __call__(self, hidden_states):
+ return self.lambda1 * hidden_states
+
+
+# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitDropPath with Beit -> Dinov2
+class FlaxDinov2DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    rate: float
+
+    @nn.module.compact
+    def __call__(self, inputs, deterministic: Optional[bool] = True):
+        # Identity when the drop rate is zero or in deterministic (eval) mode.
+        if self.rate == 0.0:
+            return inputs
+        keep_prob = 1.0 - self.rate
+        if deterministic:
+            return inputs
+        else:
+            shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+            rng = self.make_rng("droppath")
+            # floor(keep_prob + U[0,1)) is 1 with probability keep_prob, else 0.
+            random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
+            binary_tensor = jnp.floor(random_tensor)
+            # Rescale by 1/keep_prob so the expected value is unchanged.
+            output = inputs / keep_prob * binary_tensor
+            return output
+
+
+class FlaxDinov2MLP(nn.Module):
+    """Standard transformer MLP: Dense(hidden * mlp_ratio) -> activation -> Dense(hidden)."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.fc1 = nn.Dense(
+            self.config.hidden_size * self.config.mlp_ratio,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+            dtype=self.dtype,
+        )
+        self.fc2 = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+            dtype=self.dtype,
+        )
+        # `hidden_act` may be a string key into ACT2FN or a callable supplied directly.
+        if isinstance(self.config.hidden_act, str):
+            self.act = ACT2FN[self.config.hidden_act]
+        else:
+            self.act = self.config.hidden_act
+
+    def __call__(self, hidden_states):
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class FlaxDinov2SwiGLUFFN(nn.Module):
+    """SwiGLU feed-forward network: one fused input projection, silu-gated product, output projection."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        hidden_features = int(self.config.hidden_size * self.config.mlp_ratio)
+        # Take 2/3 of the MLP width (SwiGLU uses two projections) and round up to a multiple of 8.
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+        # Single Dense producing both the gate and the value halves (2 * hidden_features).
+        self.weights_in = nn.Dense(
+            2 * hidden_features,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+            dtype=self.dtype,
+        )
+        self.weights_out = nn.Dense(
+            self.config.hidden_size,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states):
+        hidden_states = self.weights_in(hidden_states)
+        # Split into gate (x1) and value (x2); SwiGLU = silu(gate) * value.
+        x1, x2 = jnp.split(hidden_states, 2, axis=-1)
+        hidden = nn.silu(x1) * x2
+        return self.weights_out(hidden)
+
+
+class FlaxDinov2Layer(nn.Module):
+    """One pre-norm DINOv2 block: norm -> attention -> LayerScale -> droppath residual, then norm -> FFN -> LayerScale -> droppath residual."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        self.attention = FlaxDinov2Attention(self.config, dtype=self.dtype)
+        self.layer_scale1 = FlaxDinov2LayerScale(self.config, dtype=self.dtype)
+        self.drop_path = FlaxDinov2DropPath(self.config.drop_path_rate)
+        self.norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+
+        # The FFN is either SwiGLU (dinov2-giant) or the plain MLP, per config.
+        if self.config.use_swiglu_ffn:
+            self.mlp = FlaxDinov2SwiGLUFFN(self.config, dtype=self.dtype)
+        else:
+            self.mlp = FlaxDinov2MLP(self.config, dtype=self.dtype)
+
+        self.layer_scale2 = FlaxDinov2LayerScale(self.config, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
+        self_attention_outputs = self.attention(
+            self.norm1(hidden_states),  # in Dinov2, layernorm is applied before self-attention
+            deterministic=deterministic,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = self_attention_outputs[0]
+
+        attention_output = self.layer_scale1(attention_output)
+
+        # Keep the attention probabilities (if requested) to re-attach at the end.
+        outputs = self_attention_outputs[1:]
+
+        # first residual connection
+        hidden_states = self.drop_path(attention_output) + hidden_states
+
+        # in Dinov2, layernorm is also applied after self-attention
+        layer_output = self.norm2(hidden_states)
+        layer_output = self.mlp(layer_output)
+        layer_output = self.layer_scale2(layer_output)
+
+        # second residual connection
+        layer_output = self.drop_path(layer_output) + hidden_states
+
+        outputs = (layer_output,) + outputs
+
+        return outputs
+
+
+# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViT->Dinov2
+class FlaxDinov2LayerCollection(nn.Module):
+    """Stack of `num_hidden_layers` FlaxDinov2Layer blocks applied in sequence."""
+
+    config: Dinov2Config
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.layers = [
+            FlaxDinov2Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
+        ]
+
+    def __call__(
+        self,
+        hidden_states,
+        deterministic: bool = True,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ):
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                # Record the hidden state *entering* each layer.
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions += (layer_outputs[1],)
+
+        if output_hidden_states:
+            # Also record the final hidden state.
+            all_hidden_states += (hidden_states,)
+
+        outputs = (hidden_states,)
+        if not return_dict:
+            # NOTE(review): only (hidden_states,) is returned here — the collected
+            # all_hidden_states/all_attentions are dropped in the tuple path. This
+            # mirrors the Flax ViT code this was copied from; verify it is intended.
+            return tuple(v for v in outputs if v is not None)
+
+        return FlaxBaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        )
+
+
# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViT->Dinov2
class FlaxDinov2Encoder(nn.Module):
    # Thin wrapper around the layer collection; kept as a separate module so the
    # parameter tree matches the checkpoint key layout ("encoder.layer.<i>...").
    config: Dinov2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layer = FlaxDinov2LayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Delegates straight through; see FlaxDinov2LayerCollection for the semantics.
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2Config
    base_model_prefix = "dinov2"
    main_input_name = "pixel_values"
    module_class: nn.Module = None  # set by concrete subclasses (e.g. FlaxDinov2Model)

    def __init__(
        self,
        config: Dinov2Config,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            # NHWC layout: Flax convolutions expect channels-last inputs
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialize the parameter tree with random weights; if `params` is given,
        only the keys recorded in `self._missing_keys` are filled from the random init."""
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        # separate PRNG streams for parameter init, dropout, and stochastic depth
        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            # complete a partial checkpoint: copy freshly initialized values for missing keys
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # resolve None flags against the config defaults
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # transpose from NCHW (PyTorch convention) to NHWC (Flax convention)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            dropout_rng, droppath_rng = jax.random.split(dropout_rng)
            rngs["dropout"] = dropout_rng
            rngs["droppath"] = droppath_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,  # deterministic: disable dropout / drop-path outside of training
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
+
+
class FlaxDinov2Module(nn.Module):
    """Bare DINOv2 transformer: embeddings -> encoder -> final layernorm, pooled via the CLS token."""

    config: Dinov2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # NOTE: these attribute names define the parameter-tree keys and must not change
        self.embeddings = FlaxDinov2Embeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxDinov2Encoder(self.config, dtype=self.dtype)
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        embedding_output = self.embeddings(pixel_values, deterministic=deterministic)

        encoder_outputs = self.encoder(
            embedding_output,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # final layernorm over the last hidden state, then pool by taking the CLS token
        last_hidden_state = self.layernorm(encoder_outputs[0])
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    "The bare Dinov2 Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_START_DOCSTRING,
)
class FlaxDinov2Model(FlaxDinov2PreTrainedModel):
    # Concrete user-facing model; all forward logic lives in FlaxDinov2Module.
    module_class = FlaxDinov2Module
+
+
# Usage example injected below into FlaxDinov2Model.__call__'s docstring.
FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxDinov2Model
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    >>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

# Attach the example and the documented return type to the model's __call__ docstring.
overwrite_call_docstring(FlaxDinov2Model, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
    FlaxDinov2Model, output_type=FlaxBaseModelOutputWithPooling, config_class=Dinov2Config
)
+
+
class FlaxDinov2ForImageClassificationModule(nn.Module):
    """DINOv2 backbone with a linear classifier over [CLS] + mean-pooled patch tokens."""

    config: Dinov2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # NOTE: attribute names define the parameter-tree keys and must not change
        self.dinov2 = FlaxDinov2Module(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.dinov2(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = backbone_outputs[0]

        # classifier input: CLS token concatenated with the mean of the patch tokens
        cls_rep = sequence_output[:, 0]
        patch_rep = sequence_output[:, 1:].mean(axis=1)
        logits = self.classifier(jnp.concatenate([cls_rep, patch_rep], axis=-1))

        if not return_dict:
            # skip the pooled output (index 1) in the tuple path
            return (logits,) + backbone_outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_START_DOCSTRING,
)
class FlaxDinov2ForImageClassification(FlaxDinov2PreTrainedModel):
    # Concrete user-facing classifier; forward logic lives in the module below.
    # NOTE(review): the head actually consumes [CLS] concatenated with mean patch
    # tokens (see FlaxDinov2ForImageClassificationModule), not the CLS token alone.
    module_class = FlaxDinov2ForImageClassificationModule
+
+
# Usage example injected below into FlaxDinov2ForImageClassification.__call__'s docstring.
FLAX_VISION_CLASSIFICATION_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxDinov2ForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
    >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
"""

# Attach the example and the documented return type to the classifier's __call__ docstring.
overwrite_call_docstring(FlaxDinov2ForImageClassification, FLAX_VISION_CLASSIFICATION_DOCSTRING)
append_replace_return_docstrings(
    FlaxDinov2ForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=Dinov2Config
)


__all__ = ["FlaxDinov2ForImageClassification", "FlaxDinov2Model", "FlaxDinov2PreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d10027b6a3b6375235a6785df044e8f0ce5fb33
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # static type checkers resolve the real symbols eagerly
    from .configuration_dinov2_with_registers import *
    from .modeling_dinov2_with_registers import *
else:
    import sys

    # at runtime, replace this module with a lazy proxy so heavy submodules
    # (and their torch dependency) are only imported on first attribute access
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0295899fcd72f6481aa4e12cbef8630652f8149
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py
@@ -0,0 +1,159 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
    Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the DINOv2 with Registers
    [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Ratio of the hidden size of the MLPs relative to the `hidden_size`.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        layerscale_value (`float`, *optional*, defaults to 1.0):
            Initial value to use for layer scale.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
            Whether to use the SwiGLU feedforward neural network.
        num_register_tokens (`int`, *optional*, defaults to 4):
            Number of register tokens to use.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        apply_layernorm (`bool`, *optional*, defaults to `True`):
            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
            seq_len, hidden_size)`.

    Example:

    ```python
    >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel

    >>> # Initializing a Dinov2WithRegisters base style configuration
    >>> configuration = Dinov2WithRegistersConfig()

    >>> # Initializing a model (with random weights) from the base style configuration
    >>> model = Dinov2WithRegistersModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dinov2_with_registers"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mlp_ratio=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        layerscale_value=1.0,
        drop_path_rate=0.0,
        use_swiglu_ffn=False,
        num_register_tokens=4,
        out_features=None,
        out_indices=None,
        apply_layernorm=True,
        reshape_hidden_states=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.layerscale_value = layerscale_value
        self.drop_path_rate = drop_path_rate
        self.use_swiglu_ffn = use_swiglu_ffn
        self.num_register_tokens = num_register_tokens
        # backbone bookkeeping: one pseudo-stage per encoder layer, plus the stem
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
        # validates/derives the two against stage_names (BackboneConfigMixin contract)
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
        self.apply_layernorm = apply_layernorm
        self.reshape_hidden_states = reshape_hidden_states


__all__ = ["Dinov2WithRegistersConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff2697f74667e1b941ec65adb1a39cfd0a87460
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py
@@ -0,0 +1,291 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DINOv2 with Registers checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/dinov2/tree/main
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from torchvision import transforms
+
+from transformers import (
+ BitImageProcessor,
+ Dinov2WithRegistersConfig,
+ Dinov2WithRegistersForImageClassification,
+ Dinov2WithRegistersModel,
+)
+from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dinov2_with_registers_config(model_name, image_classifier=False):
    """Build a `Dinov2WithRegistersConfig` matching the original checkpoint `model_name`.

    Args:
        model_name: original DINOv2 checkpoint name (must contain one of vits/vitb/vitl/vitg).
        image_classifier: if True, attach the ImageNet-1k label mapping for the linear-head variants.

    Returns:
        A configured `Dinov2WithRegistersConfig`.

    Raises:
        ValueError: if `model_name` does not name a supported architecture size.
    """
    config = Dinov2WithRegistersConfig(image_size=518, patch_size=14)

    # size of the architecture
    if "vits" in model_name:
        config.hidden_size = 384
        config.num_attention_heads = 6
    elif "vitb" in model_name:
        pass  # base is the default config
    elif "vitl" in model_name:
        config.hidden_size = 1024
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif "vitg" in model_name:
        config.use_swiglu_ffn = True
        config.hidden_size = 1536
        config.num_hidden_layers = 40
        config.num_attention_heads = 24
    else:
        raise ValueError("Model not supported")

    if image_classifier:
        repo_id = "huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        config.num_labels = 1000
        # use a context manager so the label file handle is closed
        # (the original `json.load(open(...))` leaked it)
        with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f:
            id2label = json.load(f)
        config.id2label = {int(k): v for k, v in id2label.items()}

    return config
+
+
def create_rename_keys(config):
    """Return `(original_key, hf_key)` pairs mapping DINOv2-with-registers checkpoint
    parameter names to their HuggingFace equivalents.

    The per-layer entries depend on `config.num_hidden_layers` and on whether the
    checkpoint uses a SwiGLU feed-forward (`config.use_swiglu_ffn`).
    """
    renames = [
        # patch embedding layer + special tokens
        ("cls_token", "embeddings.cls_token"),
        ("mask_token", "embeddings.mask_token"),
        ("pos_embed", "embeddings.position_embeddings"),
        ("register_tokens", "embeddings.register_tokens"),
        ("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight"),
        ("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias"),
    ]

    for idx in range(config.num_hidden_layers):
        old, new = f"blocks.{idx}", f"encoder.layer.{idx}"
        # pre-attention / pre-MLP layernorms
        renames += [
            (f"{old}.norm1.weight", f"{new}.norm1.weight"),
            (f"{old}.norm1.bias", f"{new}.norm1.bias"),
            (f"{old}.norm2.weight", f"{new}.norm2.weight"),
            (f"{old}.norm2.bias", f"{new}.norm2.bias"),
        ]
        # feed-forward: SwiGLU checkpoints use w12/w3, classic MLPs use fc1/fc2
        for part in ("w12", "w3") if config.use_swiglu_ffn else ("fc1", "fc2"):
            renames.append((f"{old}.mlp.{part}.weight", f"{new}.mlp.{part}.weight"))
            renames.append((f"{old}.mlp.{part}.bias", f"{new}.mlp.{part}.bias"))
        # layerscale gammas
        renames.append((f"{old}.ls1.gamma", f"{new}.layer_scale1.lambda1"))
        renames.append((f"{old}.ls2.gamma", f"{new}.layer_scale2.lambda1"))
        # attention output projection
        renames.append((f"{old}.attn.proj.weight", f"{new}.attention.output.dense.weight"))
        renames.append((f"{old}.attn.proj.bias", f"{new}.attention.output.dense.bias"))

    # final layernorm
    renames.append(("norm.weight", "layernorm.weight"))
    renames.append(("norm.bias", "layernorm.bias"))

    return renames
+
+
def rename_key(dct, old, new):
    """Move the value stored under key `old` to key `new` in `dct`, in place."""
    dct[new] = dct.pop(old)
+
+
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
    """Split each block's fused qkv projection into separate query/key/value entries, in place.

    The original checkpoint stores one `(3 * hidden_size, hidden_size)` weight and one
    `(3 * hidden_size,)` bias per block, with rows laid out as [query; key; value].
    """
    dim = config.hidden_size
    for layer_idx in range(config.num_hidden_layers):
        fused_weight = state_dict.pop(f"blocks.{layer_idx}.attn.qkv.weight")
        fused_bias = state_dict.pop(f"blocks.{layer_idx}.attn.qkv.bias")
        prefix = f"encoder.layer.{layer_idx}.attention.attention"
        # query: first hidden_size rows
        state_dict[f"{prefix}.query.weight"] = fused_weight[:dim, :]
        state_dict[f"{prefix}.query.bias"] = fused_bias[:dim]
        # key: middle hidden_size rows
        state_dict[f"{prefix}.key.weight"] = fused_weight[dim : dim * 2, :]
        state_dict[f"{prefix}.key.bias"] = fused_bias[dim : dim * 2]
        # value: last hidden_size rows
        state_dict[f"{prefix}.value.weight"] = fused_weight[-dim:, :]
        state_dict[f"{prefix}.value.bias"] = fused_bias[-dim:]
+
+
# We will verify our results on an image of cute cats
def prepare_img():
    """Download the standard COCO cats test image and return it as an RGB PIL image."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    return Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+
@torch.no_grad()
def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our Dinov2WithRegisters structure.

    Args:
        model_name: original checkpoint name (optionally suffixed with `_1layer` for the
            linear-head ImageNet classifier variant).
        pytorch_dump_folder_path: if not None, directory where model + processor are saved.
        push_to_hub: if True, also push the converted model and processor to the hub.
    """

    # define default Dinov2WithRegisters configuration
    image_classifier = "1layer" in model_name
    config = get_dinov2_with_registers_config(model_name, image_classifier=image_classifier)

    # load original model from torch hub (the backbone name has no `_1layer` suffix)
    original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
    original_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = original_model.state_dict()
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config)

    # rename the SwiGLU weights to the HF naming scheme
    for key, val in state_dict.copy().items():
        val = state_dict.pop(key)
        if "w12" in key:
            key = key.replace("w12", "weights_in")
        if "w3" in key:
            key = key.replace("w3", "weights_out")
        state_dict[key] = val

    # load HuggingFace model
    if image_classifier:
        model = Dinov2WithRegistersForImageClassification(config).eval()
        model.dinov2_with_registers.load_state_dict(state_dict)
        # the linear classification head is shipped separately by the original repo
        model_name_to_classifier_dict_url = {
            "dinov2_vits14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth",
            "dinov2_vitb14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth",
            "dinov2_vitl14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth",
            "dinov2_vitg14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth",
        }
        url = model_name_to_classifier_dict_url[model_name]
        classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
        model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
    else:
        model = Dinov2WithRegistersModel(config).eval()
        model.load_state_dict(state_dict)

    # load image
    image = prepare_img()

    # preprocess image
    transformations = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=IMAGENET_DEFAULT_MEAN,  # these are RGB mean+std values
                std=IMAGENET_DEFAULT_STD,  # across a large photo dataset.
            ),
        ]
    )

    original_pixel_values = transformations(image).unsqueeze(0)  # insert batch dimension

    processor = BitImageProcessor(
        size={"shortest_edge": 256},
        resample=PILImageResampling.BICUBIC,
        image_mean=IMAGENET_DEFAULT_MEAN,
        image_std=IMAGENET_DEFAULT_STD,
    )
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # sanity check: the HF processor must reproduce the torchvision preprocessing exactly
    assert torch.allclose(original_pixel_values, pixel_values)

    with torch.no_grad():
        outputs = model(pixel_values, output_hidden_states=True)
        original_outputs = original_model(pixel_values)

    # assert values
    if image_classifier:
        print("Predicted class:")
        class_idx = outputs.logits.argmax(-1).item()
        print(model.config.id2label[class_idx])
    else:
        assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
        assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        # Bug fix: parents=True so nested output paths (e.g. "out/models/x") don't fail
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving image processor to {pytorch_dump_folder_path}")
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model_name_to_hf_name = {
            "dinov2_vits14_reg": "dinov2-with-registers-small",
            "dinov2_vitb14_reg": "dinov2-with-registers-base",
            "dinov2_vitl14_reg": "dinov2-with-registers-large",
            "dinov2_vitg14_reg": "dinov2-with-registers-giant",
            "dinov2_vits14_reg_1layer": "dinov2-with-registers-small-imagenet1k-1-layer",
            "dinov2_vitb14_reg_1layer": "dinov2-with-registers-base-imagenet1k-1-layer",
            "dinov2_vitl14_reg_1layer": "dinov2-with-registers-large-imagenet1k-1-layer",
            "dinov2_vitg14_reg_1layer": "dinov2-with-registers-giant-imagenet1k-1-layer",
        }

        name = model_name_to_hf_name[model_name]
        model.push_to_hub(f"nielsr/{name}")
        processor.push_to_hub(f"nielsr/{name}")
+
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="dinov2_vits14_reg",
        type=str,
        choices=[
            "dinov2_vits14_reg",
            "dinov2_vitb14_reg",
            "dinov2_vitl14_reg",
            "dinov2_vitg14_reg",
            "dinov2_vits14_reg_1layer",
            "dinov2_vitb14_reg_1layer",
            "dinov2_vitl14_reg_1layer",
            "dinov2_vitg14_reg_1layer",
        ],
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_dinov2_with_registers_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..449bfb9b91cdcaf665c6a3e0ac4cad828d5779eb
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py
@@ -0,0 +1,930 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dinov2_with_registers.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections.abc
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+ torch_int,
+)
+from ...utils.backbone_utils import BackboneMixin
+from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig
+
+
+logger = logging.get_logger(__name__)
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
+
+# General docstring
+_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"
+
+
class Dinov2WithRegistersPatchEmbeddings(nn.Module):
    """
    Turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` that the
    Transformer consumes.
    """

    def __init__(self, config):
        super().__init__()
        # Normalize scalar sizes into (height, width) pairs.
        image_size = config.image_size
        patch_size = config.patch_size
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = config.num_channels
        # Number of non-overlapping patches tiling the configured image size.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # A strided convolution implements patchify + linear projection in one op.
        self.projection = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        channels = pixel_values.shape[1]
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels}."
            )
        # (B, C, H, W) -> (B, hidden, H/P, W/P) -> (B, num_patches, hidden)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)
+
+
class Dinov2WithRegistersEmbeddings(nn.Module):
    """
    Construct the CLS token, mask token, register tokens, position and patch embeddings.
    """

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()

        # Learnable [CLS] token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # Token substituted at masked patch positions (see `bool_masked_pos` in `forward`).
        self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        # Register tokens inserted between the [CLS] token and the patch tokens.
        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
        self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # Position embeddings cover the [CLS] token plus all patches; register tokens get none.
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
        with the original implementation.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
        - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # Skip interpolation for matching dimensions (unless tracing)
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        # Handle class token and patch embeddings separately
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]

        # Calculate new dimensions (in patches, not pixels)
        height = height // self.config.patch_size
        width = width // self.config.patch_size

        # Reshape for interpolation; the pre-trained positions form a square grid
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Store original dtype for restoration after interpolation
        target_dtype = patch_pos_embed.dtype

        # Interpolate at float32 precision
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(dtype=torch.float32),
            size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
            mode="bicubic",
            align_corners=False,
            antialias=True,
        ).to(dtype=target_dtype)

        # Validate output dimensions if not tracing
        if not torch.jit.is_tracing():
            if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
                raise ValueError("Width or height does not match with the interpolated position embeddings")

        # Reshape back to original format
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        # Combine class and patch embeddings
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Embed `pixel_values` into a token sequence: [CLS] + register tokens + patch tokens.

        `bool_masked_pos` marks patch positions to replace with the mask token (only relevant
        for pre-training, per the inputs docstring).
        """
        batch_size, _, height, width = pixel_values.shape
        # Cast the input to the projection weight dtype so mixed-precision inputs work.
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        if bool_masked_pos is not None:
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        # add register tokens right after the [CLS] token (registers receive no position encoding)
        embeddings = torch.cat(
            (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
        )

        embeddings = self.dropout(embeddings)

        return embeddings
+
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain (non-fused) scaled dot-product attention.

    Returns the attention output with the head and sequence dims swapped back
    (batch, seq, heads, head_dim), together with the attention probabilities.
    """
    # Raw attention scores: QK^T, pre-scaled by 1/sqrt(head_dim).
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Softmax in float32 for numerical stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)

    # This drops out entire tokens to attend to, which might seem unusual, but is
    # taken from the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Optional head masking, applied multiplicatively after normalization.
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()

    return context, probs
+
+
class Dinov2WithRegistersSelfAttention(nn.Module):
    # Multi-head self-attention. The attention kernel is selected at runtime from
    # `config._attn_implementation`; sdpa falls back to eager when attention
    # probabilities are requested.
    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        # 1/sqrt(head_dim) scaling applied to the attention scores.
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                # SDPA cannot return attention probabilities; warn once and keep eager.
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            # Attention dropout is only active in training mode.
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # (batch, seq, heads, head_size) -> (batch, seq, all_head_size)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
class Dinov2WithRegistersSelfOutput(nn.Module):
    """
    Projects the attention context back to `hidden_size` and applies dropout. The residual
    connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other
    models), due to the layernorm applied before each block.
    """

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface compatibility but intentionally unused;
        # the caller adds the residual itself.
        return self.dropout(self.dense(hidden_states))
+
+
class Dinov2WithRegistersAttention(nn.Module):
    """Self-attention followed by its output projection, with support for head pruning."""

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()
        self.attention = Dinov2WithRegistersSelfAttention(config)
        self.output = Dinov2WithRegistersSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given heads from the q/k/v projections and the output projection."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Shrink the linear layers to drop the pruned heads' rows/columns.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping in sync with the reduced head count.
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Append the attention probabilities when they were requested.
        return (attention_output,) + self_outputs[1:]
+
+
class Dinov2WithRegistersLayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch, initialized to `config.layerscale_value`."""

    def __init__(self, config) -> None:
        super().__init__()
        self.lambda1 = nn.Parameter(torch.ones(config.hidden_size) * config.layerscale_value)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.lambda1 * hidden_state
+
+
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # No-op outside training or when the rate is zero: return the input unchanged.
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device).add_(keep_prob).floor_()
    # Scale surviving samples by 1/keep_prob so the expected value is unchanged.
    return input.div(keep_prob) * mask
+
+
class Dinov2WithRegistersDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form; only active in training mode.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
+
+
class Dinov2WithRegistersMLP(nn.Module):
    """Two-layer feed-forward block: hidden_size -> mlp_ratio * hidden_size -> hidden_size."""

    def __init__(self, config) -> None:
        super().__init__()
        features = config.hidden_size
        expanded = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(features, expanded, bias=True)
        # A string config selects a named activation; a callable is used as-is.
        self.activation = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
        self.fc2 = nn.Linear(expanded, features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(hidden_state)))
+
+
class Dinov2WithRegistersSwiGLUFFN(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: silu(gate) * linear, then projection back."""

    def __init__(self, config) -> None:
        super().__init__()
        features = config.hidden_size
        hidden = int(config.hidden_size * config.mlp_ratio)
        # Shrink to 2/3 of the nominal width, rounded up to a multiple of 8.
        hidden = (int(hidden * 2 / 3) + 7) // 8 * 8

        self.weights_in = nn.Linear(features, 2 * hidden, bias=True)
        self.weights_out = nn.Linear(hidden, features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gate, linear = self.weights_in(hidden_state).chunk(2, dim=-1)
        return self.weights_out(nn.functional.silu(gate) * linear)
+
+
class Dinov2WithRegistersLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()

        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = Dinov2WithRegistersAttention(config)
        self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
        # Stochastic depth is only active when a positive rate is configured.
        if config.drop_path_rate > 0.0:
            self.drop_path = Dinov2WithRegistersDropPath(config.drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.mlp = Dinov2WithRegistersSwiGLUFFN(config) if config.use_swiglu_ffn else Dinov2WithRegistersMLP(config)
        self.layer_scale2 = Dinov2WithRegistersLayerScale(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # In Dinov2WithRegisters, layernorm is applied before self-attention (pre-norm).
        attn_outputs = self.attention(
            self.norm1(hidden_states),
            head_mask,
            output_attentions=output_attentions,
        )
        extras = attn_outputs[1:]  # attention probabilities, when requested

        # First residual connection: LayerScale + stochastic depth on the branch.
        hidden_states = self.drop_path(self.layer_scale1(attn_outputs[0])) + hidden_states

        # Pre-norm MLP branch, then the second residual connection.
        mlp_branch = self.layer_scale2(self.mlp(self.norm2(hidden_states)))
        layer_output = self.drop_path(mlp_branch) + hidden_states

        return (layer_output,) + extras
+
+
class Dinov2WithRegistersEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` Dinov2WithRegistersLayer blocks."""

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally (e.g. by PreTrainedModel's gradient-checkpointing helpers).
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run the layer stack; optionally collect per-layer hidden states and attentions."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            # Hidden states are collected *before* each layer (input of the layer).
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: recompute the layer during the backward pass.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # Append the final output as the last collected hidden state.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
+
+
class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Dinov2WithRegistersConfig
    base_model_prefix = "dinov2_with_registers"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    # Keep the SwiGLU FFN on a single device when the model is sharded across devices.
    _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2WithRegistersEmbeddings):
            # Truncated-normal init for position and [CLS] embeddings (via fp32, as above).
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

            # Mask and register tokens start at zero.
            module.mask_token.data.zero_()
            module.register_tokens.data.zero_()
        elif isinstance(module, Dinov2WithRegistersLayerScale):  # noqa: F821
            module.lambda1.data.fill_(self.config.layerscale_value)
+
+
# Expected output shape of the base model in the generated doc examples.
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]


# Class-level docstring fragment injected via `add_start_docstrings` on the model classes below.
DINOV2_WITH_REGISTERS_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Inputs docstring for the base model — the only class here that accepts `bool_masked_pos`.
DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
            pre-training.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.",
    DINOV2_WITH_REGISTERS_START_DOCSTRING,
)
class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel):
    def __init__(self, config: Dinov2WithRegistersConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = Dinov2WithRegistersEmbeddings(config)
        self.encoder = Dinov2WithRegistersEncoder(config)

        # Final layernorm applied to the encoder output.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Fall back to config defaults for any unspecified output options.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # The "pooled" output is simply the final (post-layernorm) [CLS] token.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            head_outputs = (sequence_output, pooled_output)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

# Inputs docstring for the heads below (no `bool_masked_pos`, unlike the base model).
DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BitImageProcessor.preprocess`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    """
    Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.
    """,
    DINOV2_WITH_REGISTERS_START_DOCSTRING,
)
class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel):
    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.dinov2_with_registers = Dinov2WithRegistersModel(config)

        # Classifier head: input is [CLS] token concatenated with the mean patch token,
        # hence `hidden_size * 2` input features.
        self.classifier = (
            nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.dinov2_with_registers(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size

        cls_token = sequence_output[:, 0]
        # NOTE(review): slicing from index 1 keeps the register tokens in this mean —
        # confirm that including registers in the pooled patch representation is intended.
        patch_tokens = sequence_output[:, 1:]

        linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)

        logits = self.classifier(linear_input)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Infer the problem type once from label count/dtype, then cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
    """,
    DINOV2_WITH_REGISTERS_START_DOCSTRING,
)
class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        # Registers stage names / out_features bookkeeping on BackboneMixin.
        super()._init_backbone(config)
        # One feature dimension per stage: the stem plus every transformer layer.
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        self.embeddings = Dinov2WithRegistersEmbeddings(config)
        self.encoder = Dinov2WithRegistersEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Register tokens sit between the CLS token and the patch tokens in the
        # sequence; they are stripped before feature maps are reshaped to 4D.
        self.num_register_tokens = config.num_register_tokens

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
        """Return the patch embedding module (the model's input embeddings)."""
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # All hidden states are always requested: intermediate stages are needed
        # to select the ones listed in `config.out_features` below.
        outputs = self.encoder(
            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    # Drop the CLS token and the register tokens; keep patch tokens only.
                    hidden_state = hidden_state[:, self.num_register_tokens + 1 :]
                    # this was actually a bug in the original implementation that we copied here,
                    # cause normally the order is height, width
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                # Skip the hidden-states slot of the tuple when it was not requested.
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
+
+
# Public symbols re-exported by the package for `from ... import *`.
__all__ = [
    "Dinov2WithRegistersPreTrainedModel",
    "Dinov2WithRegistersModel",
    "Dinov2WithRegistersForImageClassification",
    "Dinov2WithRegistersBackbone",
]
diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
new file mode 100644
index 0000000000000000000000000000000000000000..59777e2157894e7641c8512c7ff50febb75aed43
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py
@@ -0,0 +1,420 @@
+# coding=utf-8
+# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ....transformers.models.dinov2.modeling_dinov2 import (
+ Dinov2Backbone,
+ Dinov2Encoder,
+ Dinov2ForImageClassification,
+ Dinov2Model,
+ Dinov2PatchEmbeddings,
+ Dinov2PreTrainedModel,
+)
+from ...configuration_utils import PretrainedConfig
+from ...modeling_outputs import BackboneOutput
+from ...utils import logging, torch_int
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an
    Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the DINOv2 with Registers
    [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Ratio of the hidden size of the MLPs relative to the `hidden_size`.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        layerscale_value (`float`, *optional*, defaults to 1.0):
            Initial value to use for layer scale.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
            Whether to use the SwiGLU feedforward neural network.
        num_register_tokens (`int`, *optional*, defaults to 4):
            Number of register tokens to use.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        apply_layernorm (`bool`, *optional*, defaults to `True`):
            Whether to apply layer normalization to the feature maps in case the model is used as backbone.
        reshape_hidden_states (`bool`, *optional*, defaults to `True`):
            Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in
            case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size,
            seq_len, hidden_size)`.

    Example:

    ```python
    >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel

    >>> # Initializing a Dinov2WithRegisters base style configuration
    >>> configuration = Dinov2WithRegistersConfig()

    >>> # Initializing a model (with random weights) from the base style configuration
    >>> model = Dinov2WithRegistersModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "dinov2_with_registers"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mlp_ratio=4,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        layerscale_value=1.0,
        drop_path_rate=0.0,
        use_swiglu_ffn=False,
        num_register_tokens=4,
        out_features=None,
        out_indices=None,
        apply_layernorm=True,
        reshape_hidden_states=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.layerscale_value = layerscale_value
        self.drop_path_rate = drop_path_rate
        self.use_swiglu_ffn = use_swiglu_ffn
        self.num_register_tokens = num_register_tokens
        # Backbone stage naming: "stem" (embeddings) followed by one stage per layer.
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
        # Validates and reconciles out_features/out_indices against stage_names
        # (either may be None; the last stage is used when both are unset).
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
        self.apply_layernorm = apply_layernorm
        self.reshape_hidden_states = reshape_hidden_states
+
+
class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings):
    # Identical to the DINOv2 patch embeddings; subclassed only so the modular
    # converter emits a copy under this model's name.
    pass
+
+
class Dinov2WithRegistersEmbeddings(nn.Module):
    """
    Construct the CLS token, mask token, register tokens, position and patch embeddings.
    """

    def __init__(self, config: Dinov2WithRegistersConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # Replaces patch embeddings at positions flagged by `bool_masked_pos`.
        self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
        # Learnable register tokens inserted right after the CLS token in forward().
        self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
        self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        # Positional embeddings cover CLS + patches only; register tokens
        # receive no positional embedding (they are concatenated afterwards).
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
        with the original implementation.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
        - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
        """
        # First position belongs to the CLS token; the rest are patch positions.
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # Skip interpolation for matching dimensions (unless tracing)
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        # Handle class token and patch embeddings separately
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]

        # Calculate new dimensions (in patches, not pixels)
        height = height // self.config.patch_size
        width = width // self.config.patch_size

        # Reshape for interpolation: pre-trained grid is assumed square.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Store original dtype for restoration after interpolation
        target_dtype = patch_pos_embed.dtype

        # Interpolate at float32 precision
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.to(dtype=torch.float32),
            size=(torch_int(height), torch_int(width)),  # Explicit size instead of scale_factor
            mode="bicubic",
            align_corners=False,
            antialias=True,
        ).to(dtype=target_dtype)

        # Validate output dimensions if not tracing
        if not torch.jit.is_tracing():
            if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
                raise ValueError("Width or height does not match with the interpolated position embeddings")

        # Reshape back to original format
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        # Combine class and patch embeddings
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        # Cast pixels to the projection weight dtype so mixed-precision runs cleanly.
        target_dtype = self.patch_embeddings.projection.weight.dtype
        embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))

        if bool_masked_pos is not None:
            # Masked image modeling: swap masked patch embeddings for the mask token.
            embeddings = torch.where(
                bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
            )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        # add register tokens (between CLS and patch tokens, after positional
        # encoding, so registers deliberately carry no position information)
        embeddings = torch.cat(
            (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
        )

        embeddings = self.dropout(embeddings)

        return embeddings
+
+
class Dinov2WithRegistersEncoder(Dinov2Encoder):
    # Identical to the DINOv2 encoder; subclassed only so the modular converter
    # emits a copy under this model's name.
    pass
+
+
class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel):
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Dinov2WithRegistersEmbeddings):
            # Same fp32 round-trip as above for the learned position/CLS parameters.
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)

            # Mask and register tokens start at zero.
            module.mask_token.data.zero_()
            module.register_tokens.data.zero_()
        elif isinstance(module, Dinov2WithRegistersLayerScale):  # noqa: F821
            # Forward reference: LayerScale is generated from the Dinov2 source by
            # the modular converter and is not defined in this file.
            module.lambda1.data.fill_(self.config.layerscale_value)
+
+
class Dinov2WithRegistersModel(Dinov2Model):
    # Behavior is inherited unchanged from Dinov2Model; the register-token logic
    # lives in Dinov2WithRegistersEmbeddings, which the converter swaps in.
    pass
+
+
class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification):
    # Behavior is inherited unchanged from Dinov2ForImageClassification.
    pass
+
+
class Dinov2WithRegistersBackbone(Dinov2Backbone):
    def __init__(self, config):
        super().__init__(config)
        # Registers stage names / out_features bookkeeping on the backbone mixin.
        super()._init_backbone(config)

        # Register tokens sit between the CLS token and the patch tokens; they
        # are stripped in forward() before feature maps are reshaped to 4D.
        self.num_register_tokens = config.num_register_tokens
        # One feature dimension per stage: the stem plus every transformer layer.
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        self.embeddings = Dinov2WithRegistersEmbeddings(config)
        self.encoder = Dinov2WithRegistersEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
        """Return the patch embedding module (the model's input embeddings)."""
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 16, 16]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        # All hidden states are always requested: intermediate stages are needed
        # to select the ones listed in `config.out_features` below.
        outputs = self.encoder(
            embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                if self.config.apply_layernorm:
                    hidden_state = self.layernorm(hidden_state)
                if self.config.reshape_hidden_states:
                    # Drop the CLS token and the register tokens; keep patch tokens only.
                    hidden_state = hidden_state[:, self.num_register_tokens + 1 :]
                    # this was actually a bug in the original implementation that we copied here,
                    # cause normally the order is height, width
                    batch_size, _, height, width = pixel_values.shape
                    patch_size = self.config.patch_size
                    hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1)
                    hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                # Skip the hidden-states slot of the tuple when it was not requested.
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
+
+
# Public symbols the modular converter re-exports from the generated files.
__all__ = [
    "Dinov2WithRegistersConfig",
    "Dinov2WithRegistersPreTrainedModel",
    "Dinov2WithRegistersModel",
    "Dinov2WithRegistersForImageClassification",
    "Dinov2WithRegistersBackbone",
]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/__init__.py b/docs/transformers/build/lib/transformers/models/distilbert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d6fae2e0236e7619988f0cfa3502ed49d0f90b0
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers see the full sub-module API eagerly.
    from .configuration_distilbert import *
    from .modeling_distilbert import *
    from .modeling_flax_distilbert import *
    from .modeling_tf_distilbert import *
    from .tokenization_distilbert import *
    from .tokenization_distilbert_fast import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy so heavy framework
    # imports only happen when an attribute is first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a28c8e5d03d029d344c3f9f3294c9298a9fd808
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DistilBERT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class DistilBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DistilBertModel`] or a [`TFDistilBertModel`]. It
    is used to instantiate a DistilBERT model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the DistilBERT
    [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the DistilBERT model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DistilBertModel`] or [`TFDistilBertModel`].
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        sinusoidal_pos_embds (`boolean`, *optional*, defaults to `False`):
            Whether to use sinusoidal positional embeddings.
        n_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        n_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        dim (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        hidden_dim (`int`, *optional*, defaults to 3072):
            The size of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        activation (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        qa_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probabilities used in the question answering model [`DistilBertForQuestionAnswering`].
        seq_classif_dropout (`float`, *optional*, defaults to 0.2):
            The dropout probabilities used in the sequence classification and the multiple choice model
            [`DistilBertForSequenceClassification`].

    Examples:

    ```python
    >>> from transformers import DistilBertConfig, DistilBertModel

    >>> # Initializing a DistilBERT configuration
    >>> configuration = DistilBertConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DistilBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "distilbert"
    # Maps the standard `PretrainedConfig` attribute names onto DistilBERT's
    # historical names so generic code can read e.g. `config.hidden_size`.
    attribute_map = {
        "hidden_size": "dim",
        "num_attention_heads": "n_heads",
        "num_hidden_layers": "n_layers",
    }

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=512,
        sinusoidal_pos_embds=False,
        n_layers=6,
        n_heads=12,
        dim=768,
        hidden_dim=4 * 768,  # 3072, the conventional 4x FFN expansion
        dropout=0.1,
        attention_dropout=0.1,
        activation="gelu",
        initializer_range=0.02,
        qa_dropout=0.1,
        seq_classif_dropout=0.2,
        pad_token_id=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.sinusoidal_pos_embds = sinusoidal_pos_embds
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation = activation
        self.initializer_range = initializer_range
        self.qa_dropout = qa_dropout
        self.seq_classif_dropout = seq_classif_dropout
        # NOTE: attributes are set before the base-class __init__ runs, so any
        # attribute_map lookups performed there resolve correctly.
        super().__init__(**kwargs, pad_token_id=pad_token_id)
+
+
class DistilBertOnnxConfig(OnnxConfig):
    """ONNX export configuration for DistilBERT models."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Ordered mapping of model input names to their dynamic-axis labels."""
        # Multiple-choice inputs carry an extra "choice" axis between batch and sequence.
        axes = (
            {0: "batch", 1: "choice", 2: "sequence"}
            if self.task == "multiple-choice"
            else {0: "batch", 1: "sequence"}
        )
        return OrderedDict([("input_ids", axes), ("attention_mask", axes)])
+
+
+__all__ = ["DistilBertConfig", "DistilBertOnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..b78050a01aebb24f84ad3abb7b2a0925653f0541
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py
@@ -0,0 +1,1377 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
+part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
+"""
+
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import get_activation
+from ...configuration_utils import PretrainedConfig
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import (
+ apply_chunking_to_forward,
+ find_pruneable_heads_and_indices,
+ is_torch_greater_or_equal_than_2_2,
+ prune_linear_layer,
+)
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_distilbert import DistilBertConfig
+
+
+if is_flash_attn_available():
+ from ...modeling_flash_attention_utils import _flash_attention_forward
+
+
+logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
+_CONFIG_FOR_DOC = "DistilBertConfig"
+
+
+# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
+
+
def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
    """Fill `out` with sinusoidal position encodings, DeepSpeed ZeRO-3 aware."""
    if not is_deepspeed_zero3_enabled():
        _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
        return

    import deepspeed

    # Under ZeRO stage 3 the parameter is partitioned across ranks: gather it,
    # let rank 0 write the values, and DeepSpeed re-scatters on context exit.
    with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+
+
+def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
+ position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+ out.requires_grad = False
+ out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+ out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+ out.detach_()
+
+
class Embeddings(nn.Module):
    """Token embeddings plus (learned) position embeddings, then LayerNorm and dropout.

    DistilBERT has no token-type embeddings, unlike BERT.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)

        self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
        self.dropout = nn.Dropout(config.dropout)
        # Non-persistent buffer of [0, 1, ..., max_position_embeddings-1] used to
        # derive position ids without recomputing them on every forward.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Parameters:
            input_ids (torch.Tensor):
                torch.tensor(bs, max_seq_length) The token ids to embed.
            input_embeds (*optional*, torch.Tensor):
                The pre-computed word embeddings. Can only be passed if the input ids are `None`.


        Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
        embeddings)
        """
        # input_ids takes precedence; when given, the pre-computed embeddings are ignored.
        if input_ids is not None:
            input_embeds = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)

        seq_length = input_embeds.size(1)

        # Setting the position-ids to the registered buffer in constructor, it helps
        # when tracing the model without passing position-ids, solves
        # issues similar to issue #5664
        if hasattr(self, "position_ids"):
            position_ids = self.position_ids[:, :seq_length]
        else:
            # Fallback path: rebuild position ids on the fly (requires input_ids).
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)

        position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)

        embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
        embeddings = self.LayerNorm(embeddings)  # (bs, max_seq_length, dim)
        embeddings = self.dropout(embeddings)  # (bs, max_seq_length, dim)
        return embeddings
+
+
class MultiHeadSelfAttention(nn.Module):
    """Eager (pure-PyTorch) multi-head self-attention with support for head pruning."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = nn.Dropout(p=config.attention_dropout)
        # Encoder-only model: attention is never causal.
        self.is_causal = False

        # Have an even number of multi heads that divide the dimensions
        if self.dim % self.n_heads != 0:
            # Raise value errors for even multi-head attention nodes
            raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly")

        self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim)
        self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim)

        self.pruned_heads: Set[int] = set()
        self.attention_head_size = self.dim // self.n_heads

    def prune_heads(self, heads: List[int]):
        """Remove the given attention heads and shrink the projections accordingly."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.attention_head_size, self.pruned_heads
        )
        # Prune linear layers
        self.q_lin = prune_linear_layer(self.q_lin, index)
        self.k_lin = prune_linear_layer(self.k_lin, index)
        self.v_lin = prune_linear_layer(self.v_lin, index)
        self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = self.attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        # Shape the (bs, seq) padding mask will be broadcast to over heads/queries.
        mask_reshp = (bs, 1, 1, k_length)

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        # Scale queries before the matmul (equivalent to scaling the scores).
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
        # mask==0 marks padding positions; fill them with dtype-min so softmax gives ~0.
        mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
        scores = scores.masked_fill(
            mask, torch.tensor(torch.finfo(scores.dtype).min)
        )  # (bs, n_heads, q_length, k_length)

        weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
+
+
class DistilBertFlashAttention2(MultiHeadSelfAttention):
    """
    DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
    stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
    API of flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        batch_size, q_length, dim = query.size()

        dim_per_head = self.dim // self.n_heads

        def reshape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(batch_size, -1, self.n_heads, dim_per_head)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        query_states = reshape(self.q_lin(query))
        key_states = reshape(self.k_lin(key))
        value_states = reshape(self.v_lin(value))

        # Dropout is only applied during training.
        attn_dropout = self.config.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)

        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_lin.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_weights = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            mask,
            q_length,
            dropout=attn_dropout,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        # Collapse heads back into the model dimension before the output projection.
        attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
        attn_output = self.out_lin(attn_weights_reshaped)

        if output_attentions:
            # NOTE(review): flash attention does not materialize per-head weights;
            # `attn_weights` here is the attention *output*, not a probability map.
            return (attn_output, attn_weights)
        else:
            return (attn_output,)
+
+
class DistilBertSdpaAttention(MultiHeadSelfAttention):
    """Self-attention backed by `torch.nn.functional.scaled_dot_product_attention` (SDPA)."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config=config)
        self.dropout_prob = config.attention_dropout
        # torch < 2.2 has a broken memory-efficient SDPA backend with
        # non-contiguous inputs; remember whether we must force contiguity.
        self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            query: torch.tensor(bs, seq_length, dim)
            key: torch.tensor(bs, seq_length, dim)
            value: torch.tensor(bs, seq_length, dim)
            mask: torch.tensor(bs, seq_length)

        Returns:
            weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        # SDPA cannot return attention weights nor apply a head mask: fall back
        # to the eager implementation in those cases.
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
                " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
                " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
                ' removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                query,
                key,
                value,
                mask,
                head_mask,
                output_attentions,
            )

        batch_size, _, _ = query.size()
        dim_per_head = self.dim // self.n_heads

        def shape(x: torch.Tensor) -> torch.Tensor:
            """separate heads"""
            return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x: torch.Tensor) -> torch.Tensor:
            """group heads"""
            return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
        # Reference: https://github.com/pytorch/pytorch/issues/112577
        if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None:
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=False,
        )

        attn_output = unshape(attn_output)
        attn_output = self.out_lin(attn_output)

        return (attn_output,)
+
+
class FFN(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> activation -> Linear -> Dropout."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.dropout = nn.Dropout(p=config.dropout)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Chunking (if enabled) slices along the sequence axis.
        self.seq_len_dim = 1
        self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim)
        self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim)
        self.activation = get_activation(config.activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Delegate to apply_chunking_to_forward so long sequences can be
        # processed in memory-friendly slices when chunking is configured.
        return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

    def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
        # One fused expression: project up, activate, project down, dropout.
        return self.dropout(self.lin2(self.activation(self.lin1(input))))
+
+
# Dispatch table mapping `config._attn_implementation` to the attention module
# class instantiated by `TransformerBlock`.
DISTILBERT_ATTENTION_CLASSES = {
    "eager": MultiHeadSelfAttention,
    "flash_attention_2": DistilBertFlashAttention2,
    "sdpa": DistilBertSdpaAttention,
}
+
+
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention + residual/LayerNorm, then FFN + residual/LayerNorm."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()

        # Have an even number of Configure multi-heads
        if config.dim % config.n_heads != 0:
            raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly")

        # Attention backend chosen by the config (eager / flash_attention_2 / sdpa).
        self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

        self.ffn = FFN(config)
        self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim)
            attn_mask: torch.tensor(bs, seq_length)

        Returns:
            sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
            torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(
            query=x,
            key=x,
            value=x,
            mask=attn_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            if type(sa_output) is not tuple:
                raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")

            sa_output = sa_output[0]
        # Post-LN residual connection around the attention sub-layer.
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output)  # (bs, seq_length, dim)
        ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output
+
+
class Transformer(nn.Module):
    """Stack of `n_layers` TransformerBlocks with optional gradient checkpointing."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.n_layers = config.n_layers
        self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        # Toggled externally by PreTrainedModel.gradient_checkpointing_enable().
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:  # docstyle-ignore
        """
        Parameters:
            x: torch.tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top)
            layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, layer_module in enumerate(self.layer):
            # Hidden states are recorded *before* each layer (plus once after the loop),
            # so entry 0 is the embedding output.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_state,
                    attn_mask,
                    head_mask[i],
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_state,
                    attn_mask,
                    head_mask[i],
                    output_attentions,
                )

            # The block returns (attentions?, hidden_state); the new state is always last.
            hidden_state = layer_outputs[-1]

            if output_attentions:
                if len(layer_outputs) != 2:
                    raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}")

                attentions = layer_outputs[0]
                all_attentions = all_attentions + (attentions,)
            else:
                if len(layer_outputs) != 1:
                    raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}")

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
        )
+
+
+# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class DistilBertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    # No TensorFlow checkpoint conversion is supported for this architecture.
    load_tf_weights = None
    base_model_prefix = "distilbert"
    supports_gradient_checkpointing = True
    # Both optimized attention backends are implemented above.
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Padding token embeds to the zero vector.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
            # Overwrite the randomly-initialized position table with fixed sinusoids.
            create_sinusoidal_embeddings(
                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
            )
+
+
# Docstring template injected into model classes via @add_start_docstrings.
DISTILBERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-pass docstring template; `{0}` is filled with the input shape string
# by @add_start_docstrings_to_model_forward.
DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class DistilBertModel(DistilBertPreTrainedModel):
    """Embeddings + Transformer stack; no task head."""

    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

        self.embeddings = Embeddings(config)  # Embeddings
        self.transformer = Transformer(config)  # Encoder
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.embeddings.position_embeddings

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings

        # no resizing needs to be done if the length stays the same
        if num_position_embeds_diff == 0:
            return

        logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
        self.config.max_position_embeddings = new_num_position_embeddings

        old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()

        self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

        if self.config.sinusoidal_pos_embds:
            # BUG FIX: the position-embedding table lives on `self.embeddings`;
            # `self.position_embeddings` does not exist on this module and the
            # original code raised AttributeError here.
            create_sinusoidal_embeddings(
                n_pos=self.config.max_position_embeddings,
                dim=self.config.dim,
                out=self.embeddings.position_embeddings.weight,
            )
        else:
            with torch.no_grad():
                if num_position_embeds_diff > 0:
                    # Growing: keep the old learned vectors at the front, leave the
                    # freshly-initialized tail as-is.
                    self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
                        old_position_embeddings_weight
                    )
                else:
                    # Shrinking: truncate the old table.
                    self.embeddings.position_embeddings.weight = nn.Parameter(
                        old_position_embeddings_weight[:num_position_embeds_diff]
                    )
        # move position_embeddings to correct device
        self.embeddings.position_embeddings.to(self.device)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        self.embeddings.word_embeddings = new_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.transformer.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Remember whether the caller supplied a head mask before it is expanded,
        # since the SDPA fast path is only valid without one.
        head_mask_is_none = head_mask is None
        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embeddings = self.embeddings(input_ids, inputs_embeds)  # (bs, seq_length, dim)

        if self._use_flash_attention_2:
            # Flash attention takes the raw 2D mask, and only needs one at all
            # when there is actual padding (a 0 somewhere).
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            if attention_mask is None:
                attention_mask = torch.ones(input_shape, device=device)  # (bs, seq_length)

            if self._use_sdpa and head_mask_is_none and not output_attentions:
                # Expand to the additive 4D mask format SDPA expects.
                attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embeddings.dtype, tgt_len=input_shape[1]
                )

        return self.transformer(
            x=embeddings,
            attn_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
+@add_start_docstrings(
+ """DistilBert Model with a `masked language modeling` head on top.""",
+ DISTILBERT_START_DOCSTRING,
+)
+class DistilBertForMaskedLM(DistilBertPreTrainedModel):
+ _tied_weights_keys = ["vocab_projector.weight"]
+
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

        self.activation = get_activation(config.activation)

        self.distilbert = DistilBertModel(config)
        # MLM head: transform -> activation -> LayerNorm -> vocab projection
        # (the projection weight is tied to the input embeddings, see _tied_weights_keys).
        self.vocab_transform = nn.Linear(config.dim, config.dim)
        self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12)
        self.vocab_projector = nn.Linear(config.dim, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

        self.mlm_loss_fct = nn.CrossEntropyLoss()
+
    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings (delegates to the underlying `DistilBertModel`).
        """
        return self.distilbert.get_position_embeddings()
+
    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        # Pure delegation: the base encoder owns the position-embedding table.
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.vocab_projector
+
+ def set_output_embeddings(self, new_embeddings: nn.Module):
+ self.vocab_projector = new_embeddings
+
+ @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ dlbrt_output = self.distilbert(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
+ prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
+ prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim)
+ prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim)
+ prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size)
+
+ mlm_loss = None
+ if labels is not None:
+ mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_logits,) + dlbrt_output[1:]
+ return ((mlm_loss,) + output) if mlm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=mlm_loss,
+ logits=prediction_logits,
+ hidden_states=dlbrt_output.hidden_states,
+ attentions=dlbrt_output.attentions,
+ )
+
+
@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Encoder plus a small classification head applied to the first token's state.
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, config.num_labels)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embedding module of the underlying encoder."""
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The target number of position embeddings. Learned embeddings are extended with newly initialized
                vectors or truncated at the end; sinusoidal embeddings are extended following the position encoding
                algorithm or truncated at the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool by taking the hidden state of the first token, then run the head.
        cls_state = encoder_outputs[0][:, 0]  # (bs, dim)
        cls_state = nn.ReLU()(self.pre_classifier(cls_state))  # (bs, dim)
        logits = self.classifier(self.dropout(cls_state))  # (bs, num_labels)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                mse = MSELoss()
                loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = BCEWithLogitsLoss()(logits, labels)

        if not return_dict:
            head_outputs = (logits,) + encoder_outputs[1:]
            return head_outputs if loss is None else (loss,) + head_outputs

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

        self.distilbert = DistilBertModel(config)
        # Per-token head producing exactly two logits: span start and span end.
        self.qa_outputs = nn.Linear(config.dim, config.num_labels)
        if config.num_labels != 2:
            raise ValueError(f"config.num_labels should be 2, but it is {config.num_labels}")

        self.dropout = nn.Dropout(config.qa_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    # Fix: QA inputs are shaped (batch_size, sequence_length); "batch_size, num_choices"
    # was copy-pasted from the multiple-choice head and rendered a wrong shape in the docs.
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

        hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        # Split the two logits per token into start and end score vectors.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms:
            # clamp out-of-range targets to `ignored_index` and tell the loss to skip them.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        # Consistency fix: use `config.dim` like the other DistilBert heads
        # (`config.hidden_size` is an alias for `dim` on DistilBertConfig — same value).
        self.classifier = nn.Linear(config.dim, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """
        Returns the position embeddings
        """
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The number of new position embedding matrix. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    # Fix: format the `{0}` shape placeholder like the sibling heads do; passing the raw
    # template left a literal "{0}" in the generated documentation.
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # (bs, seq_length, dim)

        # Classify every token independently.
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # (bs, seq_length, num_labels)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)

        # Encoder plus a head that scores each (prompt, choice) pair with a single logit.
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.classifier = nn.Linear(config.dim, 1)
        self.dropout = nn.Dropout(config.seq_classif_dropout)

        # Initialize weights and apply final processing
        self.post_init()

    def get_position_embeddings(self) -> nn.Embedding:
        """Return the position embedding module of the underlying encoder."""
        return self.distilbert.get_position_embeddings()

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The target number of position embeddings. Learned embeddings are extended with newly initialized
                vectors or truncated at the end; sinusoidal embeddings are extended following the position encoding
                algorithm or truncated at the end.
        """
        self.distilbert.resize_position_embeddings(new_num_position_embeddings)

    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor, ...]]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, DistilBertForMultipleChoice
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
        >>> model = DistilBertForMultipleChoice.from_pretrained("distilbert-base-cased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."
        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors="pt", padding=True)
        >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1

        >>> # the linear classifier still needs to be trained
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Collapse the choice dimension so the encoder sees one row per (example, choice).
        flat_input_ids = None if input_ids is None else input_ids.view(-1, input_ids.size(-1))
        flat_attention_mask = None if attention_mask is None else attention_mask.view(-1, attention_mask.size(-1))
        flat_inputs_embeds = None
        if inputs_embeds is not None:
            flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.distilbert(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Score each row from its first token's hidden state, then regroup per example.
        cls_state = encoder_outputs[0][:, 0]  # (bs * num_choices, dim)
        cls_state = nn.ReLU()(self.pre_classifier(cls_state))  # (bs * num_choices, dim)
        logits = self.classifier(self.dropout(cls_state))  # (bs * num_choices, 1)
        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if not return_dict:
            head_outputs = (reshaped_logits,) + encoder_outputs[1:]
            return head_outputs if loss is None else (loss,) + head_outputs

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
# Public API of this module (consumed by `from ... import *` and the lazy module loader).
__all__ = [
    "DistilBertForMaskedLM",
    "DistilBertForMultipleChoice",
    "DistilBertForQuestionAnswering",
    "DistilBertForSequenceClassification",
    "DistilBertForTokenClassification",
    "DistilBertModel",
    "DistilBertPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f2b6ac96ab63590dba11b9535042d05edb58400
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py
@@ -0,0 +1,906 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Callable, Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxMaskedLMOutput,
+ FlaxMultipleChoiceModelOutput,
+ FlaxQuestionAnsweringModelOutput,
+ FlaxSequenceClassifierOutput,
+ FlaxTokenClassifierOutput,
+)
+from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_distilbert import DistilBertConfig
+
+
logger = logging.get_logger(__name__)

# Reference checkpoint / config class names — presumably consumed by the docstring
# helpers imported above (e.g. `append_call_sample_docstring`); verify against callers.
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
_CONFIG_FOR_DOC = "DistilBertConfig"
+
+
+FLAX_DISTILBERT_START_DOCSTRING = r"""
+
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models)
+
+ This model is also a
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+ a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
+ behavior.
+
+ Finally, this model supports inherent JAX features such as:
+
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+ Parameters:
+ config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+DISTILBERT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`numpy.ndarray` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
def get_angles(pos, i, d_model):
    """Return the angle arguments `pos / 10000**(2 * (i // 2) / d_model)` of the sinusoidal encoding."""
    exponents = (2 * (i // 2)) / np.float32(d_model)
    angle_rates = 1 / np.power(10000, exponents)
    return pos * angle_rates
+
+
def positional_encoding(position, d_model):
    """Build the fixed sinusoidal position-encoding table of shape `(1, position, d_model)`."""
    pos = np.arange(position)[:, np.newaxis]  # (position, 1)
    idx = np.arange(d_model)[np.newaxis, :]  # (1, d_model)
    table = get_angles(pos, idx, d_model)

    # Even feature indices (2i) take the sine of the angle ...
    table[:, 0::2] = np.sin(table[:, 0::2])
    # ... odd feature indices (2i+1) take the cosine.
    table[:, 1::2] = np.cos(table[:, 1::2])

    # Prepend a broadcastable batch axis and hand the table to JAX.
    return jnp.array(table[np.newaxis, ...])
+
+
class FlaxEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings (DistilBert has no token_type embeddings)."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Learned token embeddings, initialized from N(0, initializer_range).
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        if not self.config.sinusoidal_pos_embds:
            # Learned absolute position embeddings.
            self.position_embeddings = nn.Embed(
                self.config.max_position_embeddings,
                self.config.dim,
                embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            )
        else:
            # Fixed sinusoidal table, shape (1, max_position_embeddings, dim); not a trained parameter.
            self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim)
        self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.dropout)

    def __call__(self, input_ids, deterministic: bool = True):
        """Embed `input_ids` as word + position embeddings, then apply LayerNorm and dropout."""
        # Embed
        batch_size, seq_length = input_ids.shape
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        if not self.config.sinusoidal_pos_embds:
            # Positions are simply 0..seq_length-1, broadcast over the batch.
            position_ids = jnp.arange(seq_length).astype("i4")
            position_ids = jnp.broadcast_to(position_ids, shape=(batch_size, seq_length))
            position_embeds = self.position_embeddings(position_ids.astype("i4"))
        else:
            position_embeds = self.pos_encoding[:, :seq_length, :]
            # explicitly cast the positions here, since self.embed_positions are not registered as parameters
            position_embeds = position_embeds.astype(inputs_embeds.dtype)

        # Sum all embeddings
        hidden_states = inputs_embeds + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states
+
+
class FlaxMultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention for DistilBert."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.n_heads = self.config.n_heads
        self.dim = self.config.dim
        self.dropout = nn.Dropout(rate=self.config.attention_dropout)

        # The hidden size must split evenly across heads.
        if not (self.dim % self.n_heads == 0):
            raise ValueError(f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}")

        # Separate query/key/value projections plus the output projection, all dim -> dim.
        self.q_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.k_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.v_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.out_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        query,
        key,
        value,
        mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """
        Attend over `key`/`value` from `query`; `mask` is 1 for attendable positions, 0 for
        masked ones. Returns `(context,)`, or `(context, attention_weights)` when
        `output_attentions` is True.
        """
        bs, q_len, dim = query.shape
        k_len = key.shape[1]
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()

        dim_per_head = self.dim // self.n_heads

        # Broadcastable mask shape: one mask per key position, shared by heads and queries.
        mask_reshp = (bs, 1, 1, k_len)

        def shape(x):
            """separate heads"""
            return x.reshape(bs, -1, self.n_heads, dim_per_head).transpose(0, 2, 1, 3)

        def unshape(x):
            """group heads"""
            return x.transpose(0, 2, 1, 3).reshape(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q_lin(query))  # (bs, n_heads, q_len, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_len, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_len, dim_per_head)

        # Scale queries by sqrt(head dim) before the dot product.
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_len, dim_per_head)
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2))  # (bs, n_heads, q_len, k_len)
        mask = jnp.reshape(mask, mask_reshp)

        # Push masked positions to a very large negative score so softmax ~zeroes them.
        mask = mask.astype(scores.dtype)
        scores = scores - 1e30 * (1.0 - mask)

        weights = nn.softmax(scores, axis=-1)  # (bs, n_heads, q_len, k_len)
        weights = self.dropout(weights, deterministic=deterministic)

        context = jnp.matmul(weights, v)  # (bs, n_heads, q_len, dim_per_head)
        context = unshape(context)  # (bs, q_len, dim)
        context = self.out_lin(context)  # (bs, q_len, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)
+
+
class FlaxFFN(nn.Module):
    """Position-wise feed-forward block: dim -> hidden_dim -> dim with activation and dropout."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dropout = nn.Dropout(rate=self.config.dropout)
        self.chunk_size_feed_forward = self.config.chunk_size_feed_forward
        self.seq_len_dim = 1
        # Expansion layer (dim -> hidden_dim).
        self.lin1 = nn.Dense(
            self.config.hidden_dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        # Contraction layer (hidden_dim -> dim).
        self.lin2 = nn.Dense(
            self.config.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

        self.activation = ACT2FN[self.config.activation]

    def __call__(self, hidden_states, deterministic: bool = True):
        """Apply lin1 -> activation -> lin2 -> dropout token-wise."""
        expanded = self.activation(self.lin1(hidden_states))
        projected = self.lin2(expanded)
        return self.dropout(projected, deterministic=deterministic)
+
+
class FlaxTransformerBlock(nn.Module):
    """One DistilBert transformer layer: self-attention and FFN, each followed by a residual LayerNorm."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Explicit check instead of `assert` (asserts are stripped under `python -O`);
        # this also matches the ValueError raised by FlaxMultiHeadSelfAttention.
        if self.config.dim % self.config.n_heads != 0:
            raise ValueError(
                f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}"
            )

        self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype)
        self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)

        self.ffn = FlaxFFN(self.config, dtype=self.dtype)
        self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attn_mask,
        output_attentions: bool = False,
        deterministic: bool = True,
    ):
        """
        Returns `(hidden_states,)`, or `(attention_weights, hidden_states)` when
        `output_attentions` is True.
        """
        # Self-Attention
        sa_output = self.attention(
            query=hidden_states,
            key=hidden_states,
            value=hidden_states,
            mask=attn_mask,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )
        if output_attentions:
            sa_output, sa_weights = sa_output
        else:
            # The attention module always returns a tuple; raise a real error instead of
            # an `assert` so the invariant survives `python -O`.
            if not isinstance(sa_output, tuple):
                raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")
            sa_output = sa_output[0]
        sa_output = self.sa_layer_norm(sa_output + hidden_states)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output, deterministic=deterministic)
        ffn_output = self.output_layer_norm(ffn_output + sa_output)
        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output
+
+
class FlaxTransformer(nn.Module):
    """Stack of `n_layers` transformer blocks applied sequentially."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Blocks are named by their stringified index to match the checkpoint layout.
        self.layers = [
            FlaxTransformerBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.n_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        deterministic: bool = True,
        return_dict: bool = False,
    ):
        """Run every block, optionally collecting per-layer hidden states and attentions."""
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for block in self.layers:
            # Record the state *entering* each layer.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = block(
                hidden_states=hidden_states,
                attn_mask=attention_mask,
                output_attentions=output_attentions,
                deterministic=deterministic,
            )
            # The hidden state is always the last element of the block's output tuple.
            hidden_states = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                all_attentions += (layer_outputs[0],)
            else:
                assert len(layer_outputs) == 1

        # Also record the final layer's output.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
            )
        return tuple(v for v in (hidden_states, all_attentions, all_hidden_states) if v is not None)
+
+
class FlaxTransformerEncoder(nn.Module):
    """Thin wrapper around `FlaxTransformer`, exposing it under the `layer` attribute
    (kept for checkpoint parameter-path compatibility)."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layer = FlaxTransformer(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        deterministic: bool = True,
        return_dict: bool = False,
    ):
        # Pure delegation; every argument is forwarded unchanged.
        return self.layer(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            deterministic=deterministic,
            return_dict=return_dict,
        )
+
+
class FlaxDistilBertLMDecoder(nn.Module):
    """Vocabulary projection that reuses a kernel supplied by the caller (typically the
    transposed word-embedding matrix) and adds a learned per-token bias."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        # One bias term per vocabulary entry.
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, inputs, kernel):
        inputs = jnp.asarray(inputs, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)
        # Contract the last axis of `inputs` with the first axis of `kernel`.
        logits = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())))
        return logits + jnp.asarray(self.bias, self.dtype)
+
+
class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DistilBertConfig
    base_model_prefix = "distilbert"
    # Set by concrete subclasses (e.g. FlaxDistilBertModule).
    module_class: nn.Module = None

    def __init__(
        self,
        config: DistilBertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Instantiate the concrete Flax module chosen by the subclass.
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Randomly initialize module parameters; if `params` is given, only fill in
        the keys recorded in `self._missing_keys` from the random initialization."""
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]

        if params is not None:
            # Merge: keep provided params, take random values for missing keys only.
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        head_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # NOTE(review): `head_mask` is accepted but never forwarded to the module —
        # it appears unused here; confirm against the other framework implementations.
        # Fall back to config defaults when output flags are not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        # `not train` becomes the module's `deterministic` flag (dropout off at inference).
        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
+
+
class FlaxDistilBertModule(nn.Module):
    """Bare DistilBERT body: token embeddings followed by the transformer encoder."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embeddings = FlaxEmbeddings(self.config, dtype=self.dtype)
        self.transformer = FlaxTransformerEncoder(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Resolve None flags to the config's defaults.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        # Embed the tokens, then run the encoder stack.
        input_embeds = self.embeddings(input_ids, deterministic=deterministic)
        return self.transformer(
            hidden_states=input_embeds,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
@add_start_docstrings(
    "The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top.",
    FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertModel(FlaxDistilBertPreTrainedModel):
    # Concrete encoder-only module; all heavy lifting lives in FlaxDistilBertModule.
    module_class = FlaxDistilBertModule


# Attach a usage example to the model's __call__ docstring.
append_call_sample_docstring(FlaxDistilBertModel, _CHECKPOINT_FOR_DOC, None, _CONFIG_FOR_DOC)
+
+
class FlaxDistilBertForMaskedLMModule(nn.Module):
    """DistilBERT body plus the masked-language-modeling head
    (transform -> activation -> LayerNorm -> vocabulary projection)."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype)
        self.vocab_transform = nn.Dense(
            self.config.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
        # With tied embeddings the projector reuses the input embedding matrix;
        # otherwise it is an independent dense layer over the vocabulary.
        if self.config.tie_word_embeddings:
            self.vocab_projector = FlaxDistilBertLMDecoder(
                self.config,
                dtype=self.dtype,
            )
        else:
            self.vocab_projector = nn.Dense(
                self.config.vocab_size,
                dtype=self.dtype,
                kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            )

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            deterministic=deterministic,
            return_dict=return_dict,
        )
        # MLM head: transform each token's hidden state before projecting to the vocab.
        prediction_logits = self.vocab_transform(outputs[0])
        prediction_logits = self.vocab_layer_norm(ACT2FN[self.config.activation](prediction_logits))

        if self.config.tie_word_embeddings:
            # Pull the word-embedding matrix out of the submodule and use its transpose.
            shared_embedding = self.distilbert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_logits = self.vocab_projector(prediction_logits, shared_embedding.T)
        else:
            prediction_logits = self.vocab_projector(prediction_logits)

        if not return_dict:
            return (prediction_logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=prediction_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings("""DistilBert Model with a `language modeling` head on top.""", FLAX_DISTILBERT_START_DOCSTRING)
class FlaxDistilBertForMaskedLM(FlaxDistilBertPreTrainedModel):
    # Masked-LM head variant; see FlaxDistilBertForMaskedLMModule for the forward pass.
    module_class = FlaxDistilBertForMaskedLMModule


# Attach a usage example to the model's __call__ docstring.
append_call_sample_docstring(FlaxDistilBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
+
+
class FlaxDistilBertForSequenceClassificationModule(nn.Module):
    """DistilBERT body plus a sequence-classification head over the first token."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
        self.pre_classifier = nn.Dense(
            self.config.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode the sequence.
        outputs = self.distilbert(
            input_ids,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Pool via the first token's hidden state, then transform + ReLU + dropout.
        pooled_output = self.pre_classifier(outputs[0][:, 0])  # (bs, dim)
        pooled_output = ACT2FN["relu"](pooled_output)
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForSequenceClassification(FlaxDistilBertPreTrainedModel):
    # Sequence-classification head variant; forward pass lives in the module class.
    module_class = FlaxDistilBertForSequenceClassificationModule


# Attach a usage example to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxDistilBertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxDistilBertForMultipleChoiceModule(nn.Module):
    """DistilBERT body plus a single-logit head scored per choice; the per-choice
    logits are reshaped back to (batch, num_choices)."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
        self.pre_classifier = nn.Dense(
            self.config.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1]
        # Flatten (bs, num_choices, seq_len) -> (bs * num_choices, seq_len) for encoding.
        if input_ids is not None:
            input_ids = input_ids.reshape(-1, input_ids.shape[-1])
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])

        outputs = self.distilbert(
            input_ids,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool via the first token, transform + ReLU + dropout, then score each choice.
        pooled_output = self.pre_classifier(outputs[0][:, 0])
        pooled_output = ACT2FN["relu"](pooled_output)
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        reshaped_logits = self.classifier(pooled_output).reshape(-1, num_choices)

        if not return_dict:
            # NOTE(review): sibling heads slice `outputs[1:]`; `[2:]` skips one extra
            # element of the tuple output — kept as-is to preserve existing behavior.
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForMultipleChoice(FlaxDistilBertPreTrainedModel):
    # Multiple-choice head variant; forward pass lives in the module class.
    module_class = FlaxDistilBertForMultipleChoiceModule


# Multiple-choice inputs have an extra `num_choices` axis, so override the docstring shape.
overwrite_call_docstring(
    FlaxDistilBertForMultipleChoice, DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxDistilBertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxDistilBertForTokenClassificationModule(nn.Module):
    """DistilBERT body plus a per-token dropout + linear classification head."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode, then classify every token of the final hidden states.
        outputs = self.distilbert(
            input_ids,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.classifier(self.dropout(outputs[0], deterministic=deterministic))

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForTokenClassification(FlaxDistilBertPreTrainedModel):
    # Token-classification head variant; forward pass lives in the module class.
    module_class = FlaxDistilBertForTokenClassificationModule


# Attach a usage example to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxDistilBertForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxDistilBertForQuestionAnsweringModule(nn.Module):
    """DistilBERT body plus an extractive-QA span head producing start/end logits."""

    config: DistilBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
        assert self.config.num_labels == 2  # one logit each for span start and span end
        self.dropout = nn.Dropout(rate=self.config.qa_dropout)

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode the sequence.
        outputs = self.distilbert(
            input_ids,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.dropout(outputs[0], deterministic=deterministic)
        # Split the (bs, seq_len, 2) logits into start/end halves and drop the last axis.
        start_logits, end_logits = jnp.split(self.qa_outputs(hidden_states), self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FLAX_DISTILBERT_START_DOCSTRING,
)
class FlaxDistilBertForQuestionAnswering(FlaxDistilBertPreTrainedModel):
    # Question-answering head variant; forward pass lives in the module class.
    module_class = FlaxDistilBertForQuestionAnsweringModule


# Attach a usage example to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxDistilBertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
+
+
+__all__ = [
+ "FlaxDistilBertForMaskedLM",
+ "FlaxDistilBertForMultipleChoice",
+ "FlaxDistilBertForQuestionAnswering",
+ "FlaxDistilBertForSequenceClassification",
+ "FlaxDistilBertForTokenClassification",
+ "FlaxDistilBertModel",
+ "FlaxDistilBertPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee2f84835d41585f6133148a4596a1253e8604
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py
@@ -0,0 +1,1147 @@
+# coding=utf-8
+# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+TF 2.0 DistilBERT model
+"""
+
+from __future__ import annotations
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutput,
+ TFMaskedLMOutput,
+ TFMultipleChoiceModelOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_distilbert import DistilBertConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
+_CONFIG_FOR_DOC = "DistilBertConfig"
+
+
class TFEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.dim = config.dim
        self.initializer_range = config.initializer_range
        self.max_position_embeddings = config.max_position_embeddings
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.dropout)

    def build(self, input_shape=None):
        # Weight variables are created under explicit name scopes so checkpoint
        # variable paths stay stable across Keras versions.
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.dim],
                initializer=get_initializer(initializer_range=self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.dim],
                initializer=get_initializer(initializer_range=self.initializer_range),
            )

        # NOTE(review): this early-return guard runs *after* the add_weight calls
        # above, so the embedding weights are (re)created on every build() call —
        # confirm whether the guard was meant to come first.
        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.dim])

    def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        assert not (input_ids is None and inputs_embeds is None)

        if input_ids is not None:
            # Validate ids against the vocab size, then look up word embeddings.
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if position_ids is None:
            # Default positions: 0..seq_len-1, broadcast over the batch.
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        # Sum word + position embeddings, then normalize and apply dropout.
        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        final_embeddings = inputs_embeds + position_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
+
+
class TFMultiHeadSelfAttention(keras.layers.Layer):
    """Multi-head self-attention: projects q/k/v, computes scaled dot-product
    attention per head, and re-projects the concatenated head outputs."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.dropout = keras.layers.Dropout(config.attention_dropout)
        self.output_attentions = config.output_attentions

        assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"

        # Query / key / value / output projections (names fixed for checkpoints).
        self.q_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin"
        )
        self.k_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="k_lin"
        )
        self.v_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="v_lin"
        )
        self.out_lin = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="out_lin"
        )

        self.pruned_heads = set()
        self.config = config

    def prune_heads(self, heads):
        # Head pruning is not supported for this layer.
        raise NotImplementedError

    def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
        """
        Parameters:
            query: tf.Tensor(bs, seq_length, dim)
            key: tf.Tensor(bs, seq_length, dim)
            value: tf.Tensor(bs, seq_length, dim)
            mask: tf.Tensor(bs, seq_length)

        Returns:
            weights: tf.Tensor(bs, n_heads, seq_length, seq_length) Attention weights context: tf.Tensor(bs,
            seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
        """
        bs, q_length, dim = shape_list(query)
        k_length = shape_list(key)[1]
        # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured'
        # assert key.size() == value.size()
        dim_per_head = int(self.dim / self.n_heads)
        dim_per_head = tf.cast(dim_per_head, dtype=tf.int32)
        # Mask broadcasts over heads and query positions.
        mask_reshape = [bs, 1, 1, k_length]

        def shape(x):
            """separate heads"""
            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

        def unshape(x):
            """group heads"""
            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)
        # Scale queries by 1/sqrt(dim_per_head) in float32 for numerical stability.
        q = tf.cast(q, dtype=tf.float32)
        q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32)))
        k = tf.cast(k, dtype=q.dtype)
        scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, q_length, k_length)
        mask = tf.reshape(mask, mask_reshape)  # (bs, n_heads, qlen, klen)
        # scores.masked_fill_(mask, -float('inf'))   # (bs, n_heads, q_length, k_length)

        # Additive masking: masked positions get a large negative score before softmax.
        mask = tf.cast(mask, dtype=scores.dtype)
        scores = scores - 1e30 * (1.0 - mask)
        weights = stable_softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
        weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return (context, weights)
        else:
            return (context,)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # All four projections consume vectors of size config.dim.
        if getattr(self, "q_lin", None) is not None:
            with tf.name_scope(self.q_lin.name):
                self.q_lin.build([None, None, self.config.dim])
        if getattr(self, "k_lin", None) is not None:
            with tf.name_scope(self.k_lin.name):
                self.k_lin.build([None, None, self.config.dim])
        if getattr(self, "v_lin", None) is not None:
            with tf.name_scope(self.v_lin.name):
                self.v_lin.build([None, None, self.config.dim])
        if getattr(self, "out_lin", None) is not None:
            with tf.name_scope(self.out_lin.name):
                self.out_lin.build([None, None, self.config.dim])
+
+
class TFFFN(keras.layers.Layer):
    """Position-wise feed-forward network: dense -> activation -> dense -> dropout."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        # Expansion / projection layers; `name=` kwargs are fixed for checkpoints.
        self.lin1 = keras.layers.Dense(
            config.hidden_dim, kernel_initializer=get_initializer(config.initializer_range), name="lin1"
        )
        self.lin2 = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2"
        )
        self.activation = get_tf_activation(config.activation)
        self.dropout = keras.layers.Dropout(config.dropout)
        self.config = config

    def call(self, input, training=False):
        hidden = self.activation(self.lin1(input))
        hidden = self.lin2(hidden)
        return self.dropout(hidden, training=training)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # lin1 consumes the model dim, lin2 consumes the expanded FFN dim.
        for dense, in_features in ((getattr(self, "lin1", None), self.config.dim),
                                   (getattr(self, "lin2", None), self.config.hidden_dim)):
            if dense is not None:
                with tf.name_scope(dense.name):
                    dense.build([None, None, in_features])
+
+
class TFTransformerBlock(keras.layers.Layer):
    """One encoder layer: self-attention + residual/LayerNorm, then FFN + residual/LayerNorm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.n_heads = config.n_heads
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.dropout = keras.layers.Dropout(config.dropout)
        self.activation = config.activation
        self.output_attentions = config.output_attentions

        assert config.dim % config.n_heads == 0, (
            f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}"
        )

        self.attention = TFMultiHeadSelfAttention(config, name="attention")
        self.sa_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm")

        self.ffn = TFFFN(config, name="ffn")
        self.output_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
        self.config = config

    def call(self, x, attn_mask, head_mask, output_attentions, training=False):  # removed: src_enc=None, src_len=None
        """
        Parameters:
            x: tf.Tensor(bs, seq_length, dim)
            attn_mask: tf.Tensor(bs, seq_length)

        Outputs: sa_weights: tf.Tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output:
        tf.Tensor(bs, seq_length, dim) The output of the transformer block contextualization.
        """
        # Self-Attention
        sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
        if output_attentions:
            sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
        else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
            # assert type(sa_output) == tuple
            sa_output = sa_output[0]
        # First residual connection + LayerNorm.
        sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)

        # Feed Forward Network
        ffn_output = self.ffn(sa_output, training=training)  # (bs, seq_length, dim)
        # Second residual connection + LayerNorm.
        ffn_output = self.output_layer_norm(ffn_output + sa_output)  # (bs, seq_length, dim)

        output = (ffn_output,)
        if output_attentions:
            output = (sa_weights,) + output
        return output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Sub-layers are built under their own name scopes for checkpoint stability.
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "sa_layer_norm", None) is not None:
            with tf.name_scope(self.sa_layer_norm.name):
                self.sa_layer_norm.build([None, None, self.config.dim])
        if getattr(self, "ffn", None) is not None:
            with tf.name_scope(self.ffn.name):
                self.ffn.build(None)
        if getattr(self, "output_layer_norm", None) is not None:
            with tf.name_scope(self.output_layer_norm.name):
                self.output_layer_norm.build([None, None, self.config.dim])
+
+
class TFTransformer(keras.layers.Layer):
    """Stack of `n_layers` TFTransformerBlock layers applied sequentially."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.n_layers = config.n_layers
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions

        # Block names follow the "layer_._{i}" checkpoint convention.
        self.layer = [TFTransformerBlock(config, name=f"layer_._{i}") for i in range(config.n_layers)]

    def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False):
        # docstyle-ignore
        """
        Parameters:
            x: tf.Tensor(bs, seq_length, dim) Input sequence embedded.
            attn_mask: tf.Tensor(bs, seq_length) Attention mask on the sequence.

        Returns:
            hidden_state: tf.Tensor(bs, seq_length, dim)
                Sequence of hidden states in the last (top) layer
            all_hidden_states: Tuple[tf.Tensor(bs, seq_length, dim)]
                Tuple of length n_layers with the hidden states from each layer.
                Optional: only if output_hidden_states=True
            all_attentions: Tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)]
                Tuple of length n_layers with the attention weights from each layer
                Optional: only if output_attentions=True
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_state = x
        for i, block in enumerate(self.layer):
            # Record the state *entering* each layer.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)

            layer_outputs = block(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
            # The hidden state is always the last element of the block's output tuple.
            hidden_state = layer_outputs[-1]

            if output_attentions:
                assert len(layer_outputs) == 2
                all_attentions = all_attentions + (layer_outputs[0],)
            else:
                assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"

        # Also record the final layer's output.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions
            )
        return tuple(v for v in (hidden_state, all_hidden_states, all_attentions) if v is not None)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for block in self.layer:
                with tf.name_scope(block.name):
                    block.build(None)
+
+
@keras_serializable
class TFDistilBertMainLayer(keras.layers.Layer):
    """Core DistilBERT stack (token embeddings + transformer encoder), shared by all task models."""

    config_class = DistilBertConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.return_dict = config.use_return_dict

        self.embeddings = TFEmbeddings(config, name="embeddings")  # Embeddings
        self.transformer = TFTransformer(config, name="transformer")  # Encoder

    def get_input_embeddings(self):
        """Return the embedding layer that maps input ids to vectors."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the embedding weight matrix and sync the recorded vocab size."""
        self.embeddings.weight = value
        self.embeddings.vocab_size = value.shape[0]

    def _prune_heads(self, heads_to_prune):
        # Head pruning is not implemented for this TF port.
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Run embeddings + encoder. Exactly one of `input_ids` / `inputs_embeds` must be given."""
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is None:
            # Default: attend to every position.
            attention_mask = tf.ones(input_shape)  # (bs, seq_length)

        attention_mask = tf.cast(attention_mask, dtype=tf.float32)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            # Non-trivial head masks are not supported by this implementation.
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers

        embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds)  # (bs, seq_length, dim)
        tfmr_output = self.transformer(
            embedding_output,
            attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        return tfmr_output  # last-layer hidden-state, (all hidden_states), (all attentions)

    def build(self, input_shape=None):
        # Build sublayers inside their own name scopes so variable names match
        # the pretrained checkpoint layout.
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
+
+
# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL #
class TFDistilBertPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class used by `from_pretrained` / `save_pretrained`.
    config_class = DistilBertConfig
    # Attribute name under which the headless base model lives on task models.
    base_model_prefix = "distilbert"
+
+
# Class-level docstring prepended to every DistilBERT model class via
# `@add_start_docstrings`.
DISTILBERT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.



    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!



    Parameters:
        config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-pass docstring template; the `{0}` placeholder is filled with the
# expected input shape by `add_start_docstrings_to_model_forward`.
DISTILBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
            [`PreTrainedTokenizer.encode`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
+
+
@add_start_docstrings(
    "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.",
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertModel(TFDistilBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # Headless model: a single main layer (embeddings + encoder).
        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        # Pure delegation — the main layer implements the whole forward pass.
        return self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "distilbert", None)
        if main_layer is not None:
            # Build under the sublayer's name scope to keep checkpoint naming.
            with tf.name_scope(main_layer.name):
                main_layer.build(None)
+
+
class TFDistilBertLMHead(keras.layers.Layer):
    """Masked-LM output head projecting hidden states onto the vocabulary.

    The projection matrix is the (transposed) input embedding matrix — weight
    tying — so the only variable created here is a per-token output bias.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.dim = config.dim

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.input_embeddings = input_embeddings

    def build(self, input_shape):
        # Only the output bias is newly trained; the decoder weight is tied.
        self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")

        super().build(input_shape)

    def get_output_embeddings(self):
        return self.input_embeddings

    def set_output_embeddings(self, value):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = shape_list(value)[0]

    def get_bias(self):
        return {"bias": self.bias}

    def set_bias(self, value):
        self.bias = value["bias"]
        self.config.vocab_size = shape_list(value["bias"])[0]

    def call(self, hidden_states):
        seq_length = shape_list(tensor=hidden_states)[1]
        # Flatten (bs, seq, dim) -> (bs*seq, dim), decode against the tied
        # embedding matrix, then restore the sequence dimension and add bias.
        flat_states = tf.reshape(tensor=hidden_states, shape=[-1, self.dim])
        scores = tf.matmul(a=flat_states, b=self.input_embeddings.weight, transpose_b=True)
        scores = tf.reshape(tensor=scores, shape=[-1, seq_length, self.config.vocab_size])
        return tf.nn.bias_add(value=scores, bias=self.bias)
+
+
@add_start_docstrings(
    """DistilBert Model with a `masked language modeling` head on top.""",
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.config = config

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        # MLM head pipeline: dense -> activation -> layer norm -> vocab projection.
        self.vocab_transform = keras.layers.Dense(
            config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform"
        )
        self.act = get_tf_activation(config.activation)
        self.vocab_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm")
        # Projection shares weights with the input embeddings (see TFDistilBertLMHead).
        self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector")

    def get_lm_head(self):
        """Return the language-modeling head layer."""
        return self.vocab_projector

    def get_prefix_bias_name(self):
        # Deprecated accessor kept for backward compatibility.
        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
        return self.name + "/" + self.vocab_projector.name

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = distilbert_output[0]  # (bs, seq_length, dim)
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = self.act(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)

        loss = None if labels is None else self.hf_compute_loss(labels, prediction_logits)

        if not return_dict:
            # Tuple output: (loss?, logits, *encoder extras).
            output = (prediction_logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        # Build every sublayer in its own name scope for checkpoint compatibility.
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "vocab_transform", None) is not None:
            with tf.name_scope(self.vocab_transform.name):
                self.vocab_transform.build([None, None, self.config.dim])
        if getattr(self, "vocab_layer_norm", None) is not None:
            with tf.name_scope(self.vocab_layer_norm.name):
                self.vocab_layer_norm.build([None, None, self.config.dim])
        if getattr(self, "vocab_projector", None) is not None:
            with tf.name_scope(self.vocab_projector.name):
                self.vocab_projector.build(None)
+
+
@add_start_docstrings(
    """
    DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        # "Pooling" is a dense+ReLU over the first token's hidden state (below).
        self.pre_classifier = keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="relu",
            name="pre_classifier",
        )
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim) — first-token ([CLS]-style) pooling
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "pre_classifier", None) is not None:
            with tf.name_scope(self.pre_classifier.name):
                self.pre_classifier.build([None, None, self.config.dim])
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.dim])
+
+
@add_start_docstrings(
    """
    DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
    for Named-Entity-Recognition (NER) tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.dropout = keras.layers.Dropout(config.dropout)
        # Per-token linear classifier over the encoder hidden states.
        self.classifier = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]  # (bs, seq_len, dim)
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)  # (bs, seq_len, num_labels)
        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                # NOTE(review): `hidden_size` presumably resolves to `dim` via the
                # config's attribute mapping (the other heads here use `config.dim`)
                # — confirm against DistilBertConfig.
                self.classifier.build([None, None, self.config.hidden_size])
+
+
@add_start_docstrings(
    """
    DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for RocStories/SWAG tasks.
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        self.dropout = keras.layers.Dropout(config.seq_classif_dropout)
        self.pre_classifier = keras.layers.Dense(
            config.dim,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="relu",
            name="pre_classifier",
        )
        # One score per choice; scores are reshaped to (bs, num_choices) in call().
        self.classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(
        DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
        """
        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]

        # Fold choices into the batch axis: (bs, num_choices, seq) -> (bs*num_choices, seq)
        # so each choice is encoded independently by the shared encoder.
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_inputs_embeds = (
            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
            if inputs_embeds is not None
            else None
        )
        distilbert_output = self.distilbert(
            flat_input_ids,
            flat_attention_mask,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim) — first-token pooling
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output, training=training)  # (bs, dim)
        logits = self.classifier(pooled_output)  # one score per flattened choice
        # Unfold back to (bs, num_choices) for the softmax/cross-entropy loss.
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)

        if not return_dict:
            output = (reshaped_logits,) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFMultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "pre_classifier", None) is not None:
            with tf.name_scope(self.pre_classifier.name):
                self.pre_classifier.build([None, None, self.config.dim])
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.dim])
+
+
@add_start_docstrings(
    """
    DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
    linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    DISTILBERT_START_DOCSTRING,
)
class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.distilbert = TFDistilBertMainLayer(config, name="distilbert")
        # Two outputs per token: span-start and span-end logits.
        self.qa_outputs = keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        # NOTE(review): this validation uses `assert`, which is stripped under
        # `python -O`; an explicit raise would be more robust.
        assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2"
        self.dropout = keras.layers.Dropout(config.qa_dropout)
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        start_positions: np.ndarray | tf.Tensor | None = None,
        end_positions: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
        r"""
        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states, training=training)  # (bs, max_query_len, dim)
        logits = self.qa_outputs(hidden_states)  # (bs, max_query_len, 2)
        # Separate the two per-token channels into start/end logit tensors.
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        loss = None
        if start_positions is not None and end_positions is not None:
            # Loss requires both ends of the span to be labelled.
            labels = {"start_position": start_positions}
            labels["end_position"] = end_positions
            loss = self.hf_compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (start_logits, end_logits) + distilbert_output[1:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=distilbert_output.hidden_states,
            attentions=distilbert_output.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "distilbert", None) is not None:
            with tf.name_scope(self.distilbert.name):
                self.distilbert.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.config.dim])
+
+
# Public names exported by this module.
__all__ = [
    "TFDistilBertForMaskedLM",
    "TFDistilBertForMultipleChoice",
    "TFDistilBertForQuestionAnswering",
    "TFDistilBertForSequenceClassification",
    "TFDistilBertForTokenClassification",
    "TFDistilBertMainLayer",
    "TFDistilBertModel",
    "TFDistilBertPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py
new file mode 100644
index 0000000000000000000000000000000000000000..c894211a2e0acf2bd82858186e88f7e3e99f672e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py
@@ -0,0 +1,522 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DistilBERT."""
+
+import collections
+import os
+import unicodedata
+from typing import List, Optional, Tuple
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
# Module-level logger obtained through the library's logging wrapper.
logger = logging.get_logger(__name__)

# Maps the tokenizer's vocabulary attribute to its canonical on-disk filename.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
# Copied from transformers.models.bert.tokenization_bert.load_vocab
def load_vocab(vocab_file):
    """Load a WordPiece vocabulary file into an ordered token -> index mapping.

    Each line of the file holds exactly one token; a token's id is its
    zero-based line number.

    Args:
        vocab_file: Path to a UTF-8 text file with one token per line.

    Returns:
        `collections.OrderedDict` mapping each token to its line index.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        # Iterate the file handle lazily instead of materializing every line
        # with readlines(): identical result, O(1) extra memory.
        for index, token in enumerate(reader):
            vocab[token.rstrip("\n")] = index
    return vocab
+
+
# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
def whitespace_tokenize(text):
    """Trim *text* and split it on runs of whitespace (blank input -> [])."""
    stripped = text.strip()
    return stripped.split() if stripped else []
+
+
class DistilBertTokenizer(PreTrainedTokenizer):
    r"""
    DistilBERT tokenizer (WordPiece-based), sharing BERT's vocabulary format.

    This tokenizer inherits from [`PreTrainedTokenizer`]; refer to that superclass
    for the bulk of the API (encoding, decoding, padding, truncation, ...).

    Args:
        vocab_file (`str`):
            File containing the vocabulary, one token per line.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to run basic (whitespace/punctuation) tokenization
            before WordPiece.
        never_split (`Iterable`, *optional*):
            Tokens that must never be split. Only effective when
            `do_basic_tokenize=True`.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            Token substituted for anything absent from the vocabulary.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            Separator token, placed between the sequences of a pair and as the
            last special token of an encoded sequence.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            Token used for padding when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            Classifier token, prepended to every encoded sequence.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            Token used to mask values for masked-language-model training.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters; should likely be
            deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. When unspecified, follows the
            value of `lowercase` (as in the original BERT).
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
            Whether or not to remove spacing artifacts (e.g. extra spaces) when
            decoding.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        clean_up_tokenization_spaces=True,
        **kwargs,
    ):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        # Token -> id mapping plus its inverse, both kept in index order.
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (index, token) for token, index in self.vocab.items()
        )
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))

        # The parent __init__ may tokenize special tokens, so the
        # sub-tokenizers above must be in place before it runs.
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case
    def do_lower_case(self):
        # Mirrors the setting held by the basic tokenizer.
        return self.basic_tokenizer.do_lower_case

    @property
    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size
    def vocab_size(self):
        # Size of the base vocabulary (added tokens excluded).
        return len(self.vocab)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab
    def get_vocab(self):
        # Base vocabulary merged with any tokens added after loading.
        return dict(self.vocab, **self.added_tokens_encoder)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
    def _tokenize(self, text, split_special_tokens=False):
        """Split *text* into WordPiece tokens, after basic tokenization if enabled."""
        if not self.do_basic_tokenize:
            return self.wordpiece_tokenizer.tokenize(text)
        keep_intact = None if split_special_tokens else self.all_special_tokens
        pieces = []
        for word in self.basic_tokenizer.tokenize(text, never_split=keep_intact):
            if word in self.basic_tokenizer.never_split:
                # Protected tokens bypass WordPiece entirely.
                pieces.append(word)
            else:
                pieces.extend(self.wordpiece_tokenizer.tokenize(word))
        return pieces

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id
    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        unk_id = self.vocab.get(self.unk_token)
        return self.vocab.get(token, unk_id)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string
    def convert_tokens_to_string(self, tokens):
        """Join word pieces back into a single string, fusing `##` continuations."""
        return " ".join(tokens).replace(" ##", "").strip()

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Add the BERT-style special tokens around one or two sequences:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                IDs of the optional second sequence of a pair.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        if token_ids_1 is not None:
            output = output + token_ids_1 + [self.sep_token_id]
        return output

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask
    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Return a mask with 1 at special-token positions and 0 elsewhere, for a
        sequence (pair) that does not yet contain special tokens. When
        `already_has_special_tokens` is set, the generic superclass
        implementation handles the fully formatted input instead.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            mask = mask + [0] * len(token_ids_1) + [1]
        return mask

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build BERT-style segment ids for a sequence pair:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, only the first (all-zero) portion is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        first_segment = 1 + len(token_ids_0) + 1  # [CLS] A [SEP]
        if token_ids_1 is None:
            return [0] * first_segment
        return [0] * first_segment + [1] * (len(token_ids_1) + 1)

    # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the vocabulary to disk, one token per line ordered by index."""
        prefix = filename_prefix + "-" if filename_prefix else ""
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, prefix + VOCAB_FILES_NAMES["vocab_file"])
        else:
            vocab_file = prefix + save_directory
        expected_index = 0
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                # A gap in the indices means line number no longer equals id.
                if expected_index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    expected_index = token_index
                writer.write(token + "\n")
                expected_index += 1
        return (vocab_file,)
+
+
# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
class BasicTokenizer:
    """
    Runs basic tokenization: invalid-character cleanup, optional lower-casing and
    accent-stripping, CJK-character isolation, and punctuation splitting.

    Args:
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters; should likely be
            deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. When unspecified, follows the
            value of `lowercase` (as in the original BERT).
        do_split_on_punc (`bool`, *optional*, defaults to `True`):
            When `False`, punctuation splitting is skipped so later tokenization
            can see full words such as contractions.
    """

    def __init__(
        self,
        do_lower_case=True,
        never_split=None,
        tokenize_chinese_chars=True,
        strip_accents=None,
        do_split_on_punc=True,
    ):
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split) if never_split is not None else set()
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents
        self.do_split_on_punc = do_split_on_punc

    def tokenize(self, text, never_split=None):
        """
        Basic-tokenize a piece of text; sub-word splitting is WordPieceTokenizer's job.

        Args:
            never_split (`List[str]`, *optional*)
                Kept for backward compatibility (now handled at the base-class
                level, see [`PreTrainedTokenizer.tokenize`]); extra tokens to
                protect, merged with the instance-level set.
        """
        protected = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # Added on November 1st, 2018 for the multilingual and Chinese models;
        # also applied to the English models, which is harmless since those were
        # not trained on meaningful amounts of Chinese data.
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        # NFC-normalize so the same character written with different codepoint
        # sequences compares equal.
        text = unicodedata.normalize("NFC", text)
        pieces = []
        for token in whitespace_tokenize(text):
            if token not in protected:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            pieces.extend(self._run_split_on_punc(token, protected))

        return whitespace_tokenize(" ".join(pieces))

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        decomposed = unicodedata.normalize("NFD", text)
        # Drop the combining marks (category "Mn") left by the decomposition.
        return "".join(char for char in decomposed if unicodedata.category(char) != "Mn")

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if not self.do_split_on_punc or (never_split is not None and text in never_split):
            return [text]
        pieces = []
        start_new_word = True
        for char in text:
            if _is_punctuation(char):
                # Every punctuation character becomes its own piece.
                pieces.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    pieces.append([])
                start_new_word = False
                pieces[-1].append(char)
        return ["".join(piece) for piece in pieces]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        return "".join(
            f" {char} " if self._is_chinese_char(ord(char)) else char for char in text
        )

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # "Chinese character" here means anything in the CJK Unicode blocks:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        # Note these blocks do NOT cover all Japanese and Korean characters:
        # Hangul, Hiragana and Katakana live in other blocks and are written with
        # space-separated words, so they are handled like any other language.
        cjk_ranges = (
            (0x4E00, 0x9FFF),
            (0x3400, 0x4DBF),
            (0x20000, 0x2A6DF),
            (0x2A700, 0x2B73F),
            (0x2B740, 0x2B81F),
            (0x2B820, 0x2CEAF),
            (0xF900, 0xFAFF),
            (0x2F800, 0x2FA1F),
        )
        return any(low <= cp <= high for low, high in cjk_ranges)

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        cleaned = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                # Drop NUL, replacement characters and control characters.
                continue
            cleaned.append(" " if _is_whitespace(char) else char)
        return "".join(cleaned)
+
+
# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
class WordpieceTokenizer:
    """Runs WordPiece tokenization (greedy longest-match-first against a vocab)."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        # Words longer than this are mapped straight to `unk_token`.
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy
        longest-match-first algorithm to perform tokenization using the given
        vocabulary.

        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.

        Returns:
            A list of wordpiece tokens.
        """
        output_tokens = []
        for word in whitespace_tokenize(text):
            if len(word) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            pieces = []
            start = 0
            while start < len(word):
                # Shrink the candidate from the right until it is in the vocab;
                # non-initial pieces carry the "##" continuation prefix.
                end = len(word)
                match = None
                while start < end:
                    candidate = word[start:end]
                    if start > 0:
                        candidate = "##" + candidate
                    if candidate in self.vocab:
                        match = candidate
                        break
                    end -= 1
                if match is None:
                    # Some remainder of the word cannot be matched at all.
                    pieces = None
                    break
                pieces.append(match)
                start = end

            if pieces is None:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(pieces)
        return output_tokens
+
+
# Public API of this module.
__all__ = ["DistilBertTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3829763d5e7ab8e2a338c53a0f7dd50c3e4b737
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py
@@ -0,0 +1,179 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DistilBERT."""
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import logging
+from .tokenization_distilbert import DistilBertTokenizer
+
+
# Module-level logger obtained through the library's logging wrapper.
logger = logging.get_logger(__name__)

# On-disk filenames for the slow vocabulary and the serialized fast tokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
class DistilBertTokenizerFast(PreTrainedTokenizerFast):
    r"""
    "Fast" DistilBERT tokenizer backed by HuggingFace's *tokenizers* library
    (WordPiece-based).

    This tokenizer inherits from [`PreTrainedTokenizerFast`]; refer to that
    superclass for the bulk of the API.

    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            Token substituted for anything absent from the vocabulary.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            Separator token, placed between the sequences of a pair and as the
            last special token of an encoded sequence.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            Token used for padding when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            Classifier token, prepended to every encoded sequence.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            Token used to mask values for masked-language-model training.
        clean_text (`bool`, *optional*, defaults to `True`):
            Whether or not to strip control characters and normalize all
            whitespace before tokenization.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters; should likely be
            deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. When unspecified, follows the
            value of `lowercase` (as in the original BERT).
        wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
            The prefix for subwords.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = DistilBertTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # If the serialized normalizer disagrees with the requested options
        # (e.g. a checkpoint saved with different casing), rebuild it so the
        # backend honors the arguments passed here.
        state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        wanted = {
            "lowercase": do_lower_case,
            "strip_accents": strip_accents,
            "handle_chinese_chars": tokenize_chinese_chars,
        }
        if any(state.get(key, value) != value for key, value in wanted.items()):
            normalizer_class = getattr(normalizers, state.pop("type"))
            state.update(wanted)
            self.backend_tokenizer.normalizer = normalizer_class(**state)

        self.do_lower_case = do_lower_case

    # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.build_inputs_with_special_tokens
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Add the BERT-style special tokens around one or two sequences:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                IDs of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                IDs of the optional second sequence of a pair.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        return (
            [self.cls_token_id]
            + token_ids_0
            + [self.sep_token_id]
            + token_ids_1
            + [self.sep_token_id]
        )

    # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.create_token_type_ids_from_sequences
    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build BERT-style segment ids for a sequence pair:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, only the first (all-zero) portion is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        first_segment = 1 + len(token_ids_0) + 1  # [CLS] A [SEP]
        if token_ids_1 is None:
            return [0] * first_segment
        return [0] * first_segment + [1] * (len(token_ids_1) + 1)

    # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.save_vocabulary
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Delegate vocabulary serialization to the backend tokenizer model."""
        saved = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(saved)
+
+
# Public API of this module.
__all__ = ["DistilBertTokenizerFast"]
diff --git a/docs/transformers/build/lib/transformers/models/dit/__init__.py b/docs/transformers/build/lib/transformers/models/dit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c5b22e3b9a2dd2037660902febd8069ca41a7d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py
@@ -0,0 +1,230 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DiT checkpoints from the unilm repository."""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import BeitConfig, BeitForImageClassification, BeitForMaskedImageModeling, BeitImageProcessor
+from transformers.image_utils import PILImageResampling
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
    """Build the (original_key, converted_key) rename pairs for a DiT/BEiT checkpoint.

    Args:
        config: BEiT configuration; only `num_hidden_layers` is read.
        has_lm_head: Whether the checkpoint keeps the masked-image-modeling head
            (mask token + final layernorm) rather than a classification head.
        is_semantic: Whether the original keys carry a "backbone." prefix.

    Returns:
        `list[tuple[str, str]]` of rename pairs.
    """
    prefix = "backbone." if is_semantic else ""

    rename_keys = []
    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
    for i in range(config.num_hidden_layers):
        old = f"{prefix}blocks.{i}"
        new = f"beit.encoder.layer.{i}"
        rename_keys += [
            (f"{old}.norm1.weight", f"{new}.layernorm_before.weight"),
            (f"{old}.norm1.bias", f"{new}.layernorm_before.bias"),
            (f"{old}.attn.proj.weight", f"{new}.attention.output.dense.weight"),
            (f"{old}.attn.proj.bias", f"{new}.attention.output.dense.bias"),
            (f"{old}.norm2.weight", f"{new}.layernorm_after.weight"),
            (f"{old}.norm2.bias", f"{new}.layernorm_after.bias"),
            (f"{old}.mlp.fc1.weight", f"{new}.intermediate.dense.weight"),
            (f"{old}.mlp.fc1.bias", f"{new}.intermediate.dense.bias"),
            (f"{old}.mlp.fc2.weight", f"{new}.output.dense.weight"),
            (f"{old}.mlp.fc2.bias", f"{new}.output.dense.bias"),
        ]

    # projection layer + position embeddings
    rename_keys += [
        (f"{prefix}cls_token", "beit.embeddings.cls_token"),
        (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
        (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
        (f"{prefix}pos_embed", "beit.embeddings.position_embeddings"),
    ]

    if has_lm_head:
        # mask token + layernorm
        rename_keys += [
            ("mask_token", "beit.embeddings.mask_token"),
            ("norm.weight", "layernorm.weight"),
            ("norm.bias", "layernorm.bias"),
        ]
    else:
        # layernorm + classification head
        rename_keys += [
            ("fc_norm.weight", "beit.pooler.layernorm.weight"),
            ("fc_norm.bias", "beit.pooler.layernorm.bias"),
            ("head.weight", "classifier.weight"),
            ("head.bias", "classifier.bias"),
        ]

    return rename_keys
+
+
# we split up the fused qkv matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
    """Split fused qkv projections and move layer-scale parameters, in place.

    The original checkpoint stacks query/key/value into one matrix and only
    stores biases for query and value (no key bias). gamma_1/gamma_2 are
    re-keyed as lambda_1/lambda_2 because otherwise they would be renamed
    when using `.from_pretrained`.
    """
    prefix = "backbone." if is_semantic else ""
    hidden = config.hidden_size
    for i in range(config.num_hidden_layers):
        src = f"{prefix}blocks.{i}"
        dst = f"beit.encoder.layer.{i}"

        # fused weight: rows [0, h) are query, [h, 2h) are key, the last h are value
        qkv_weight = state_dict.pop(f"{src}.attn.qkv.weight")
        state_dict[f"{dst}.attention.attention.query.weight"] = qkv_weight[:hidden, :]
        state_dict[f"{dst}.attention.attention.query.bias"] = state_dict.pop(f"{src}.attn.q_bias")
        state_dict[f"{dst}.attention.attention.key.weight"] = qkv_weight[hidden : hidden * 2, :]
        state_dict[f"{dst}.attention.attention.value.weight"] = qkv_weight[-hidden:, :]
        state_dict[f"{dst}.attention.attention.value.bias"] = state_dict.pop(f"{src}.attn.v_bias")

        # layer scale parameters: gamma_* -> lambda_*
        state_dict[f"{dst}.lambda_1"] = state_dict.pop(f"{src}.gamma_1")
        state_dict[f"{dst}.lambda_2"] = state_dict.pop(f"{src}.gamma_2")
+
+
def rename_key(dct, old, new):
    """Move the entry stored under `old` to the key `new`, in place."""
    dct[new] = dct.pop(old)
+
+
# We will verify our results on an image of cute cats
def prepare_img():
    """Download and return the standard COCO cats validation image as a PIL image."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
+
+
@torch.no_grad()
def convert_dit_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub=False):
    """
    Copy/paste/tweak model's weights to our BEiT structure.

    Args:
        checkpoint_url (`str`): URL of the original unilm `.pth` checkpoint.
        pytorch_dump_folder_path (`str`): Folder where the converted model and image processor are saved.
        push_to_hub (`bool`, *optional*, defaults to `False`): Whether to also push the converted artifacts to the hub.
    """
    # define default BEiT configuration; only the RVL-CDIP fine-tuned checkpoints
    # carry a classification head, all others keep the masked-image-modeling head
    has_lm_head = "rvlcdip" not in checkpoint_url
    config = BeitConfig(use_absolute_position_embeddings=True, use_mask_token=has_lm_head)

    # size of the architecture (base-sized values are the BeitConfig defaults)
    if "large" in checkpoint_url or "dit-l" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16

    # labels for the RVL-CDIP document-classification checkpoints
    if "rvlcdip" in checkpoint_url:
        config.num_labels = 16
        repo_id = "huggingface/label-files"
        filename = "rvlcdip-id2label.json"
        # use a context manager so the downloaded label file is closed promptly
        with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f:
            id2label = json.load(f)
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    # load state_dict of original model, remove and rename some keys
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]

    rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)

    # load HuggingFace model
    model = BeitForMaskedImageModeling(config) if has_lm_head else BeitForImageClassification(config)
    model.eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image
    image_processor = BeitImageProcessor(
        size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False
    )
    image = prepare_img()

    encoding = image_processor(images=image, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    outputs = model(pixel_values)
    logits = outputs.logits

    # verify logits: [1, num_labels] for classification, [1, num_patches, vocab] for MIM
    expected_shape = [1, 16] if "rvlcdip" in checkpoint_url else [1, 196, 8192]
    assert logits.shape == torch.Size(expected_shape), "Shape of logits not as expected"

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        if has_lm_head:
            model_name = "dit-base" if "base" in checkpoint_url else "dit-large"
        else:
            model_name = "dit-base-finetuned-rvlcdip" if "dit-b" in checkpoint_url else "dit-large-finetuned-rvlcdip"
        image_processor.push_to_hub(
            repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
            organization="nielsr",
            commit_message="Add image processor",
            use_temp_dir=True,
        )
        model.push_to_hub(
            repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
            organization="nielsr",
            commit_message="Add model",
            use_temp_dir=True,
        )
+
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the DiT -> BEiT conversion.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_url",
        default="https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth",
        type=str,
        help="URL to the original PyTorch checkpoint (.pth file).",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
    )
    args = parser.parse_args()
    convert_dit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/donut/__init__.py b/docs/transformers/build/lib/transformers/models/donut/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..834c451f78fa0d4c5fe91f59719b6505c4c4e4e5
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # static type checkers resolve the real submodule exports directly
    from .configuration_donut_swin import *
    from .feature_extraction_donut import *
    from .image_processing_donut import *
    from .image_processing_donut_fast import *
    from .modeling_donut_swin import *
    from .processing_donut import *
else:
    import sys

    # at runtime, replace this module with a lazy proxy that imports each
    # submodule only when one of its attributes is first accessed
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py b/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aac07dace7688273be0bdc57da0a12663c2fb5b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py
@@ -0,0 +1,135 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Donut Swin Transformer model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class DonutSwinConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DonutSwinModel`]. It is used to instantiate a
    Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Donut
    [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 4):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embed_dim (`int`, *optional*, defaults to 96):
            Dimensionality of patch embedding.
        depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
            Depth of each layer in the Transformer encoder.
        num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
            Number of attention heads in each layer of the Transformer encoder.
        window_size (`int`, *optional*, defaults to 7):
            Size of windows.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether or not a learnable bias should be added to the queries, keys and values.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
            `"selu"` and `"gelu_new"` are supported.
        use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to add absolute position embeddings to the patch embeddings.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.

    Example:

    ```python
    >>> from transformers import DonutSwinConfig, DonutSwinModel

    >>> # Initializing a Donut naver-clova-ix/donut-base style configuration
    >>> configuration = DonutSwinConfig()

    >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration
    >>> model = DonutSwinModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "donut-swin"

    # map the generic PretrainedConfig attribute names onto Swin's own names
    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=224,
        patch_size=4,
        num_channels=3,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=False,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        **kwargs,
    ):
        # remaining kwargs (e.g. id2label, torch_dtype) are handled by PretrainedConfig
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        # one encoder stage per entry in `depths`
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        # (the embedding dim doubles at every stage: embed_dim * 2 ** (num_stages - 1))
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+
+
+__all__ = ["DonutSwinConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py b/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6f14f6d08e31037389f448815242b388545fd15
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py
@@ -0,0 +1,234 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Donut checkpoints using the original `donut-python` library. URL: https://github.com/clovaai/donut"""
+
+import argparse
+
+import torch
+from datasets import load_dataset
+from donut import DonutModel
+
+from transformers import (
+ DonutImageProcessor,
+ DonutProcessor,
+ DonutSwinConfig,
+ DonutSwinModel,
+ MBartConfig,
+ MBartForCausalLM,
+ VisionEncoderDecoderModel,
+ XLMRobertaTokenizerFast,
+)
+
+
def get_configs(model):
    """Derive the HF encoder (DonutSwin) and decoder (MBart) configurations
    from an original `donut-python` model.

    Args:
        model: Original `DonutModel`; reads `model.config` and the decoder
            tokenizer's vocabulary size.

    Returns:
        Tuple of (`DonutSwinConfig`, `MBartConfig`).
    """
    original_config = model.config

    encoder_config = DonutSwinConfig(
        # assumes original input_size is directly usable as image_size — TODO confirm ordering/shape
        image_size=original_config.input_size,
        patch_size=4,
        depths=original_config.encoder_layer,
        num_heads=[4, 8, 16, 32],
        window_size=original_config.window_size,
        embed_dim=128,
    )
    decoder_config = MBartConfig(
        is_decoder=True,
        is_encoder_decoder=False,
        add_cross_attention=True,
        decoder_layers=original_config.decoder_layer,
        max_position_embeddings=original_config.max_position_embeddings,
        vocab_size=len(
            model.decoder.tokenizer
        ),  # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json)
        scale_embedding=True,
        add_final_layer_norm=True,
    )

    return encoder_config, decoder_config
+
+
def rename_key(name):
    """Map one original Donut checkpoint parameter name to its HF equivalent."""
    # global prefix/embedding renames, applied in order
    for old, new in (
        ("encoder.model", "encoder"),
        ("decoder.model", "decoder"),
        ("patch_embed.proj", "embeddings.patch_embeddings.projection"),
        ("patch_embed.norm", "embeddings.norm"),
    ):
        if old in name:
            name = name.replace(old, new)

    if name.startswith("encoder"):
        # the HF Swin encoder lives under an extra "encoder." scope
        if "layers" in name:
            name = "encoder." + name
        # "attn.proj" must be handled before the generic "attn" substitution
        if "attn.proj" in name:
            name = name.replace("attn.proj", "attention.output.dense")
        if "attn" in name and "mask" not in name:
            name = name.replace("attn", "attention.self")
        for old, new in (
            ("norm1", "layernorm_before"),
            ("norm2", "layernorm_after"),
            ("mlp.fc1", "intermediate.dense"),
            ("mlp.fc2", "output.dense"),
        ):
            if old in name:
                name = name.replace(old, new)

    # final encoder LayerNorm keeps its HF name
    if name in ("encoder.norm.weight", "encoder.norm.bias"):
        name = name.replace("norm", "layernorm")

    return name
+
+
def convert_state_dict(orig_state_dict, model):
    """Convert the original Donut state dict to HF naming, mutating it in place.

    Fused qkv projections are split into separate query/key/value tensors (a
    one-to-many mapping that `rename_key` cannot express); attention-mask
    buffers and the final encoder LayerNorm are dropped because the HF
    implementation does not use them. Returns the mutated dict.
    """
    # iterate over a snapshot of the keys since entries are popped and re-added
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if "qkv" in key:
            # key looks like "encoder.model.layers.{L}.blocks.{B}.attn.qkv.{weight|bias}"
            key_split = key.split(".")
            layer_num = int(key_split[3])
            block_num = int(key_split[5])
            # per-projection output dimension: rows [0,d) query, [d,2d) key, last d value
            dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size

            if "weight" in key:
                orig_state_dict[
                    f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
                ] = val[:dim, :]
                orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = (
                    val[dim : dim * 2, :]
                )
                orig_state_dict[
                    f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
                ] = val[-dim:, :]
            else:
                orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = (
                    val[:dim]
                )
                orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = (
                    val[dim : dim * 2]
                )
                orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = (
                    val[-dim:]
                )
        elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]:
            # HuggingFace implementation doesn't use attn_mask buffer
            # and model doesn't use final LayerNorms for the encoder
            pass
        else:
            orig_state_dict[rename_key(key)] = val

    return orig_state_dict
+
+
def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
    """Convert an original `donut-python` checkpoint to a HF VisionEncoderDecoderModel.

    Loads the original model, rebuilds it as DonutSwin encoder + MBart decoder,
    converts the state dict, verifies patch embeddings / encoder states / logits
    against the original on a sample document, then optionally saves and pushes.
    """
    # load original model
    original_model = DonutModel.from_pretrained(model_name).eval()

    # load HuggingFace model
    encoder_config, decoder_config = get_configs(original_model)
    encoder = DonutSwinModel(encoder_config)
    decoder = MBartForCausalLM(decoder_config)
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
    model.eval()

    state_dict = original_model.state_dict()
    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    # verify results on scanned document
    dataset = load_dataset("hf-internal-testing/example-documents")  # no-script
    image = dataset["test"][0]["image"].convert("RGB")

    tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name, from_slow=True)
    image_processor = DonutImageProcessor(
        do_align_long_axis=original_model.config.align_long_axis, size=original_model.config.input_size[::-1]
    )
    processor = DonutProcessor(image_processor, tokenizer)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # NOTE(review): several task-prompt literals below look truncated (e.g. "s_cord-v2>"
    # and the empty strings) — likely lost "<s_...>" markup during extraction; verify
    # against the upstream conversion script before relying on them.
    if model_name == "naver-clova-ix/donut-base-finetuned-docvqa":
        task_prompt = "{user_input}"
        question = "When is the coffee break?"
        task_prompt = task_prompt.replace("{user_input}", question)
    elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip":
        task_prompt = ""
    elif model_name in [
        "naver-clova-ix/donut-base-finetuned-cord-v1",
        "naver-clova-ix/donut-base-finetuned-cord-v1-2560",
    ]:
        task_prompt = ""
    elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2":
        task_prompt = "s_cord-v2>"
    elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket":
        task_prompt = ""
    elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]:
        # use a random prompt
        task_prompt = "hello world"
    else:
        raise ValueError("Model name not supported")
    prompt_tensors = original_model.decoder.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")[
        "input_ids"
    ]

    # verify patch embeddings
    original_patch_embed = original_model.encoder.model.patch_embed(pixel_values)
    patch_embeddings, _ = model.encoder.embeddings(pixel_values)
    assert torch.allclose(original_patch_embed, patch_embeddings, atol=1e-3)

    # verify encoder hidden states
    original_last_hidden_state = original_model.encoder(pixel_values)
    last_hidden_state = model.encoder(pixel_values).last_hidden_state
    assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2)

    # verify decoder hidden states
    original_logits = original_model(pixel_values, prompt_tensors, None).logits
    logits = model(pixel_values, decoder_input_ids=prompt_tensors).logits
    assert torch.allclose(original_logits, logits, atol=1e-3)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        model.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model")
        processor.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model")
+
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the Donut conversion.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="naver-clova-ix/donut-base-finetuned-docvqa",
        required=False,
        type=str,
        help="Name of the original model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        required=False,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the converted model and processor to the 🤗 hub.",
    )

    args = parser.parse_args()
    convert_donut_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py b/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py
new file mode 100644
index 0000000000000000000000000000000000000000..e37a58ddd3055e040c6c29cbd5f5cc3c34270cbe
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for Donut."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_donut import DonutImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
@requires(backends=("vision",))
class DonutFeatureExtractor(DonutImageProcessor):
    """Deprecated alias of [`DonutImageProcessor`], kept for backward compatibility."""

    def __init__(self, *args, **kwargs) -> None:
        # warn on every instantiation, then defer entirely to DonutImageProcessor
        warnings.warn(
            "The class DonutFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use DonutImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)
+
+
+__all__ = ["DonutFeatureExtractor"]
diff --git a/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d051859a70d2cd83c1fc5c538361534aa82e7d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py
@@ -0,0 +1,477 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Donut."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+ convert_to_rgb,
+ get_resize_output_image_size,
+ pad,
+ resize,
+ to_channel_dimension_format,
+)
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+from ...utils.import_utils import is_vision_available, requires
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_vision_available():
+ import PIL
+
+
@requires(backends=("vision",))
class DonutImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Donut image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in the `preprocess` method.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 2560, "width": 1920}`):
            Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. If given as a
            tuple or list, it is interpreted in the legacy (width, height) order and flipped.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
        do_thumbnail (`bool`, *optional*, defaults to `True`):
            Whether to resize the image using thumbnail method.
        do_align_long_axis (`bool`, *optional*, defaults to `False`):
            Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image. If `random_padding` is set to `True` in `preprocess`, each image is padded with a
            random amount of padding on each side, up to the largest image size in the batch. Otherwise, all images are
            padded to the largest image size in the batch.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Image standard deviation. Can be overridden by the `image_std` parameter in the `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        do_thumbnail: bool = True,
        do_align_long_axis: bool = False,
        do_pad: bool = True,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        size = size if size is not None else {"height": 2560, "width": 1920}
        if isinstance(size, (tuple, list)):
            # The previous feature extractor size parameter was in (width, height) format
            size = size[::-1]
        # Normalize to the canonical {"height": h, "width": w} dict form.
        size = get_size_dict(size)

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_thumbnail = do_thumbnail
        self.do_align_long_axis = do_align_long_axis
        self.do_pad = do_pad
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD

    def align_long_axis(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Align the long axis of the image to the longest axis of the specified size.

        Args:
            image (`np.ndarray`):
                The image to be aligned.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to align the long axis to.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.

        Returns:
            `np.ndarray`: The aligned image.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = size["height"], size["width"]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(image)

        # Pick the spatial axes to rotate around, depending on where the channel dim sits.
        if input_data_format == ChannelDimension.LAST:
            rot_axes = (0, 1)
        elif input_data_format == ChannelDimension.FIRST:
            rot_axes = (1, 2)
        else:
            raise ValueError(f"Unsupported data format: {input_data_format}")

        # Rotate by 270° (three 90° steps) only when the image's long axis
        # disagrees with the target size's long axis.
        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3, axes=rot_axes)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

        return image

    def pad_image(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        random_padding: bool = False,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Pad the image to the specified size.

        Args:
            image (`np.ndarray`):
                The image to be padded.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to pad the image to.
            random_padding (`bool`, *optional*, defaults to `False`):
                Whether to use random padding or not.
            data_format (`str` or `ChannelDimension`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        # NOTE(review): assumes the input already fits inside `size` (non-negative
        # deltas) — `np.random.randint` raises if `high <= 0`. Presumably guaranteed
        # by the preceding resize/thumbnail steps; confirm for direct callers.
        delta_width = output_width - input_width
        delta_height = output_height - input_height

        if random_padding:
            # Random split of the slack between top/bottom and left/right.
            pad_top = np.random.randint(low=0, high=delta_height + 1)
            pad_left = np.random.randint(low=0, high=delta_width + 1)
        else:
            # Deterministic: center the image in the padded canvas.
            pad_top = delta_height // 2
            pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format)

    def pad(self, *args, **kwargs):
        # Deprecated thin wrapper kept for backward compatibility.
        logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
        return self.pad_image(*args, **kwargs)

    def thumbnail(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
        corresponding dimension of the specified size, preserving the input aspect ratio.

        Args:
            image (`np.ndarray`):
                The image to be resized.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to resize the image to.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                The resampling filter to use.
            data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
                The data format of the output image. If unset, the same format as the input image is used.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = size["height"], size["width"]

        # We always resize to the smallest of either the input or output size.
        height = min(input_height, output_height)
        width = min(input_width, output_width)

        # Already within bounds: nothing to do.
        if height == input_height and width == input_width:
            return image

        # Shrink the longer side's counterpart to keep the aspect ratio.
        if input_height > input_width:
            width = int(input_width * height / input_height)
        elif input_width > input_height:
            height = int(input_height * width / input_width)

        return resize(
            image,
            size=(height, width),
            resample=resample,
            reducing_gap=2.0,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resizes `image` to `(height, width)` specified by `size` using the PIL library.

        The image's shortest edge is resized to `min(size["height"], size["width"])`; the
        other edge scales to preserve the aspect ratio.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        size = get_size_dict(size)
        shortest_edge = min(size["height"], size["width"])
        output_size = get_resize_output_image_size(
            image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format
        )
        resized_image = resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        return resized_image

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        do_thumbnail: Optional[bool] = None,
        do_align_long_axis: Optional[bool] = None,
        do_pad: Optional[bool] = None,
        random_padding: bool = False,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to min(size["height"],
                size["width"]) with the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
                Whether to resize the image using thumbnail method.
            do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
                Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
                amount of padding on each side, up to the largest image size in the batch. Otherwise, all images are
                padded to the largest image size in the batch.
            random_padding (`bool`, *optional*, defaults to `self.random_padding`):
                Whether to use random padding when padding the image. If `True`, each image in the batch with be padded
                with a random amount of padding on each side up to the size of the largest image in the batch.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image pixel values.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: defaults to the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `BatchFeature`: A `BatchFeature` with a single `"pixel_values"` key.
        """
        # Fall back to instance defaults for any argument not given explicitly.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        if isinstance(size, (tuple, list)):
            # Previous feature extractor had size in (width, height) format
            size = size[::-1]
        size = get_size_dict(size)
        resample = resample if resample is not None else self.resample
        do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail
        do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis
        do_pad = do_pad if do_pad is not None else self.do_pad
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisibility=size,  # There is no pad divisibility in this processor, but pad requires the size arg.
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # Pipeline order matters: align -> resize -> thumbnail -> pad -> rescale -> normalize.
        if do_align_long_axis:
            images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images]

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_thumbnail:
            images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images]

        if do_pad:
            images = [
                self.pad_image(
                    image=image, size=size, random_padding=random_padding, input_data_format=input_data_format
                )
                for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["DonutImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..be83b5cc5c6bd1b83d23ff1df9a853fbe023a74e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py
@@ -0,0 +1,289 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for Donut."""
+
+from typing import Optional, Union
+
+from ...image_processing_utils_fast import (
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+)
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+if is_torch_available():
+ import torch
+
+if is_torchvision_available():
+ if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+ else:
+ from torchvision.transforms import functional as F
+
+
class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    # Donut-specific preprocessing flags accepted by `preprocess`, in addition
    # to the base fast-image-processor kwargs.
    do_thumbnail: Optional[bool]
    do_align_long_axis: Optional[bool]
    do_pad: Optional[bool]
+
+
@add_start_docstrings(
    "Constructs a fast Donut image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    """
    do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
        Whether to resize the image using thumbnail method.
    do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
        Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
    do_pad (`bool`, *optional*, defaults to `self.do_pad`):
        Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
        amount of padding on each side, up to the largest image size in the batch. Otherwise, all images are
        padded to the largest image size in the batch.
    """,
)
class DonutImageProcessorFast(BaseImageProcessorFast):
    # Defaults mirror the slow `DonutImageProcessor`.
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 2560, "width": 1920}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_thumbnail = True
    do_align_long_axis = False
    do_pad = True
    valid_kwargs = DonutFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[DonutFastImageProcessorKwargs]):
        # Backwards compatibility: the legacy feature extractor accepted `size`
        # as (width, height); flip it to (height, width) before the base class
        # normalizes it.
        size = kwargs.pop("size", None)
        if isinstance(size, (tuple, list)):
            size = size[::-1]
        kwargs["size"] = size
        super().__init__(**kwargs)

    @add_start_docstrings(
        BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
        """
        do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`):
            Whether to resize the image using thumbnail method.
        do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`):
            Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees.
        do_pad (`bool`, *optional*, defaults to `self.do_pad`):
            Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random
            amount of padding on each side, up to the largest image size in the batch. Otherwise, all images are
            padded to the largest image size in the batch.
        """,
    )
    def preprocess(self, images: ImageInput, **kwargs: Unpack[DonutFastImageProcessorKwargs]) -> BatchFeature:
        # Same (width, height) -> (height, width) backward compatibility as __init__.
        if "size" in kwargs:
            size = kwargs.pop("size")
            if isinstance(size, (tuple, list)):
                size = size[::-1]
            kwargs["size"] = size
        return super().preprocess(images, **kwargs)

    def align_long_axis(
        self,
        image: "torch.Tensor",
        size: SizeDict,
    ) -> "torch.Tensor":
        """
        Align the long axis of the image to the longest axis of the specified size.

        Args:
            image (`torch.Tensor`):
                The image to be aligned.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to align the long axis to.

        Returns:
            `torch.Tensor`: The aligned image.
        """
        input_height, input_width = image.shape[-2:]
        output_height, output_width = size.height, size.width

        # Rotate by 270° (three 90° steps) only when the image's long axis
        # disagrees with the target size's long axis.
        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            height_dim, width_dim = image.dim() - 2, image.dim() - 1
            image = torch.rot90(image, 3, dims=[height_dim, width_dim])

        return image

    def pad_image(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        random_padding: bool = False,
    ) -> "torch.Tensor":
        """
        Pad the image to the specified size.

        Args:
            image (`torch.Tensor`):
                The image to be padded.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to pad the image to.
            random_padding (`bool`, *optional*, defaults to `False`):
                Whether to use random padding or not.
        """
        output_height, output_width = size.height, size.width
        input_height, input_width = image.shape[-2:]

        # NOTE(review): assumes the input already fits inside `size` (non-negative
        # deltas); presumably guaranteed by the preceding resize/thumbnail steps.
        delta_width = output_width - input_width
        delta_height = output_height - input_height

        if random_padding:
            # Bug fix: `torch.random` has no `randint` attribute, so the previous
            # `torch.random.randint(...)` raised AttributeError. Draw scalar
            # offsets with `torch.randint` (exclusive upper bound, hence `+ 1`).
            pad_top = int(torch.randint(low=0, high=delta_height + 1, size=(1,)).item())
            pad_left = int(torch.randint(low=0, high=delta_width + 1, size=(1,)).item())
        else:
            # Deterministic: center the image in the padded canvas.
            pad_top = delta_height // 2
            pad_left = delta_width // 2

        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        # torchvision's pad expects (left, top, right, bottom).
        padding = (pad_left, pad_top, pad_right, pad_bottom)
        return F.pad(image, padding)

    def pad(self, *args, **kwargs):
        # Deprecated thin wrapper kept for backward compatibility.
        logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.")
        return self.pad_image(*args, **kwargs)

    def thumbnail(
        self,
        image: "torch.Tensor",
        size: SizeDict,
    ) -> "torch.Tensor":
        """
        Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
        corresponding dimension of the specified size, preserving the input aspect ratio.

        Args:
            image (`torch.Tensor`):
                The image to be resized.
            size (`Dict[str, int]`):
                The size `{"height": h, "width": w}` to resize the image to.
        """
        input_height, input_width = image.shape[-2:]
        output_height, output_width = size.height, size.width

        # We always resize to the smallest of either the input or output size.
        height = min(input_height, output_height)
        width = min(input_width, output_width)

        # Already within bounds: nothing to do.
        if height == input_height and width == input_width:
            return image

        # Shrink the longer side's counterpart to keep the aspect ratio.
        if input_height > input_width:
            width = int(input_width * height / input_height)
        elif input_width > input_height:
            height = int(input_height * width / input_width)

        return self.resize(
            image,
            size=SizeDict(width=width, height=height),
            interpolation=F.InterpolationMode.BICUBIC,
        )

    def _preprocess(
        self,
        images: list["torch.Tensor"],
        do_resize: bool,
        do_thumbnail: bool,
        do_align_long_axis: bool,
        do_pad: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_center_crop: bool,
        crop_size: SizeDict,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """Run the Donut pipeline (align -> resize -> thumbnail -> pad -> crop -> rescale/normalize)
        on batches of same-shaped images for speed."""
        # Group images by size for batched resizing
        grouped_images, grouped_images_index = group_images_by_shape(images)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_align_long_axis:
                stacked_images = self.align_long_axis(image=stacked_images, size=size)
            if do_resize:
                shortest_edge = min(size.height, size.width)
                stacked_images = self.resize(
                    image=stacked_images, size=SizeDict(shortest_edge=shortest_edge), interpolation=interpolation
                )
            if do_thumbnail:
                stacked_images = self.thumbnail(image=stacked_images, size=size)
            if do_pad:
                # NOTE(review): random padding is not plumbed through the fast path;
                # padding is always deterministic (centered) here.
                stacked_images = self.pad_image(image=stacked_images, size=size, random_padding=False)

            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_center_crop:
                stacked_images = self.center_crop(stacked_images, crop_size)
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


__all__ = ["DonutImageProcessorFast"]
diff --git a/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py b/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d8f3f10798be368f95c285d932d937c7e39178
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py
@@ -0,0 +1,1145 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Donut Swin Transformer model.
+
+This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden
+states."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ torch_int,
+)
+from .configuration_donut_swin import DonutSwinConfig
+
+
# Module-level logger, namespaced to this file.
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "DonutSwinConfig"

# Base docstring
# NOTE(review): checkpoint identifiers in this library are usually bare repo ids
# ("naver-clova-ix/donut-base"); the full URL here mirrors the upstream file but is
# unusual — confirm the docstring tooling accepts it.
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
+
+
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
class DonutSwinEncoderOutput(ModelOutput):
    """
    DonutSwin encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Every field defaults to None so the output can be built with only the pieces
    # that were actually requested (e.g. attentions disabled).
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
class DonutSwinModelOutput(ModelOutput):
    """
    DonutSwin model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Extends the encoder output with an optional pooled representation; all fields
    # default to None so partially-populated outputs are representable.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin
class DonutSwinImageClassifierOutput(ModelOutput):
    """
    DonutSwin outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # loss is only populated when labels are supplied; the remaining fields mirror
    # the base model output and default to None.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.

    Splits a `(batch, height, width, channels)` tensor into non-overlapping square windows of side
    `window_size`, returning a tensor of shape `(num_windows * batch, window_size, window_size, channels)`.
    Assumes height and width are divisible by `window_size` (callers pad beforehand).
    """
    batch, height, width, channels = input_feature.shape
    rows, cols = height // window_size, width // window_size
    # Carve each spatial axis into (grid index, intra-window index).
    tiled = input_feature.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two grid axes together, then fold them into the batch axis.
    return tiled.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, channels)
+
+
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.

    Inverse of `window_partition`: reassembles a `(num_windows * batch, window_size, window_size, channels)`
    tensor into a `(batch, height, width, channels)` feature map.
    """
    channels = windows.shape[-1]
    rows, cols = height // window_size, width // window_size
    # Recover the window-grid layout, one row of windows at a time per batch element.
    grid = windows.view(-1, rows, cols, window_size, window_size, channels)
    # Interleave grid axes with intra-window axes to restore the original spatial order.
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, channels)
+
+
# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        # Learned token substituted for masked patches (masked-image-modeling style
        # pretraining); only allocated when requested.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            # NOTE(review): the table has num_patches + 1 slots, matching the ViT-style code this was
            # copied from (where slot 0 is a class token), but a Swin patch sequence has num_patches
            # tokens — confirm how the non-interpolated addition below is expected to broadcast.
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        # Entry 0 of the table is treated as a class-token slot (ViT heritage); only the
        # remaining num_positions entries are spatially interpolated.
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Pretrained table is assumed to cover a square grid; recover its side length.
        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Bicubic resize in (height, width) space, then flatten back to a sequence.
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        """
        Embed `pixel_values` into a normalized patch sequence, optionally replacing masked patches
        with the mask token and adding (possibly interpolated) absolute position embeddings.

        Returns the embedded sequence and the `(height, width)` patch-grid dimensions reported by
        the patch embedding layer.
        """
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
+
+
# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin
class DonutSwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        # Normalize scalar sizes into (height, width) pairs.
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        # Patch counts at the configured image size: per-axis grid, then total.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        # Non-overlapping convolution: each kernel application embeds exactly one patch.
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        """Zero-pad the right/bottom edges so height and width become divisible by the patch size."""
        width_remainder = width % self.patch_size[1]
        if width_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, self.patch_size[1] - width_remainder))
        height_remainder = height % self.patch_size[0]
        if height_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, self.patch_size[0] - height_remainder))
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        """Embed pixels into a patch sequence; also return the resulting (height, width) patch grid."""
        _, num_channels, height, width = pixel_values.shape
        # Pad the input so it is divisible by the patch size, if needed.
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, grid_height, grid_width = embeddings.shape
        # Flatten the spatial grid into a sequence: (B, hidden, H', W') -> (B, H'*W', hidden).
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, (grid_height, grid_width)
+
+
# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Downsamples a feature map by two in each spatial direction: every 2x2 neighbourhood of patches is
    concatenated channel-wise (4 * dim), normalized, then linearly projected down to 2 * dim channels.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        """Pad the bottom/right edge with one row/column of zeros when height or width is odd."""
        if height % 2 == 1 or width % 2 == 1:
            input_feature = nn.functional.pad(input_feature, (0, 0, 0, width % 2, 0, height % 2))
        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        """Merge 2x2 patch groups of a `(batch, height*width, dim)` sequence into a `(batch, ceil(h/2)*ceil(w/2), 2*dim)` one."""
        height, width = input_dimensions
        # The sequence axis has length height * width.
        batch_size, _, num_channels = input_feature.shape

        feature_map = input_feature.view(batch_size, height, width, num_channels)
        # Pad so both spatial dimensions are even before the 2x2 regrouping.
        feature_map = self.maybe_pad(feature_map, height, width)
        # The four corners of each 2x2 block, in the order used by the original Swin implementation.
        corners = [
            feature_map[:, 0::2, 0::2, :],  # top-left
            feature_map[:, 1::2, 0::2, :],  # bottom-left
            feature_map[:, 0::2, 1::2, :],  # top-right
            feature_map[:, 1::2, 1::2, :],  # bottom-right
        ]
        merged = torch.cat(corners, -1)  # batch_size, height/2, width/2, 4*num_channels
        merged = merged.view(batch_size, -1, 4 * num_channels)  # batch_size, height/2 * width/2, 4*num_channels

        merged = self.norm(merged)
        return self.reduction(merged)
+
+
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # Identity outside training or at zero drop probability.
    if not training or drop_prob == 0.0:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dimension.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = torch.rand(mask_shape, dtype=input.dtype, device=input.device).add_(keep_prob).floor_()  # binarize
    # Scale survivors by 1/keep_prob so the expected activation matches the input.
    return input.div(keep_prob) * mask
+
+
# Copied from transformers.models.swin.modeling_swin.SwinDropPath
class DonutSwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        # Probability of dropping the entire residual path for a given sample.
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form; the module's training flag decides whether anything drops.
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown in the module's repr, e.g. "DonutSwinDropPath(p=0.1)".
        return f"p={self.drop_prob}"
+
+
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
class DonutSwinSelfAttention(nn.Module):
    """Window-based multi-head self-attention with a learned relative position bias."""

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )

        # One learnable bias per head for each possible relative (dy, dx) offset inside a window:
        # (2*Wh - 1) * (2*Ww - 1) distinct offsets.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to be non-negative, then fold (dy, dx) into a single flat table index.
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        # Registered as a buffer: moves with the module's device/dtype but is not a parameter.
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Attend within windows. `hidden_states` is (num_windows * batch, window_area, channels);
        `attention_mask` is the additive shifted-window mask whose leading dim is the number of
        windows; `head_mask` multiplies the attention probabilities per head. Returns the context
        layer and, when `output_attentions` is set, the attention probabilities.
        """
        # NOTE: `dim` here is the sequence length (window area), not the channel count.
        batch_size, dim, num_channels = hidden_states.shape
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Gather the per-pair bias from the flat table and reshape to (heads, area, area).
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )

        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
            # Un-fold the window axis out of the batch so the per-window mask broadcasts correctly.
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.view(
                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
            )
            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
class DonutSwinSelfOutput(nn.Module):
    """Projects the attention context back to `dim` and applies dropout (no residual is added here)."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        # Mirrors upstream Swin, which reuses the attention-probability dropout rate for this projection.
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for interface compatibility but intentionally unused.
        return self.dropout(self.dense(hidden_states))
+
+
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
class DonutSwinAttention(nn.Module):
    """Window self-attention plus its output projection, with support for head pruning."""

    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
        self.output = DonutSwinSelfOutput(config, dim)
        # Heads already removed by prune_heads; kept so later prunes compute correct offsets.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads by slicing the q/k/v and output projections in place."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run self-attention, project the context, and append attention weights when requested."""
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
+
+
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
class DonutSwinIntermediate(nn.Module):
    """First half of the transformer MLP: expand `dim` by `config.mlp_ratio`, then apply the activation."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        # `hidden_act` may be a string key into ACT2FN or a callable supplied directly.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
# Copied from transformers.models.swin.modeling_swin.SwinOutput
class DonutSwinOutput(nn.Module):
    """Second half of the transformer MLP: project the expanded features back down to `dim`, then dropout."""

    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.dense(hidden_states))
+
+
# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
class DonutSwinLayer(nn.Module):
    """
    A single Swin transformer block: (optionally shifted) window attention followed by an MLP, each with a
    pre-LayerNorm and a residual connection; stochastic depth is applied on the attention residual.
    """

    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # shift_size > 0 selects the shifted-window variant (SW-MSA); 0 is plain W-MSA.
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        """Disable shifting and clamp the window to the whole input when the input is smaller than a window."""
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = torch_int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        """Build the additive attention mask (-100 on cross-region pairs) for shifted windows, or None."""
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            # Label each of the 9 shift regions with a distinct id; tokens from different regions
            # must not attend to each other after the cyclic shift.
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Nonzero differences mark pairs from different regions; bias them to -100 pre-softmax.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        """Zero-pad bottom/right so the spatial dims are multiples of the window size; returns (padded, pad_values)."""
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply windowed attention and the MLP to a `(batch, height*width, channels)` sequence.
        `always_partition` skips the window-size clamping (used to keep behavior static, e.g. for tracing).
        """
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        else:
            pass
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)

        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape
        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(
            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
        )

        attention_outputs = self.attention(
            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
        )

        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        # Drop any rows/columns introduced by maybe_pad before flattening back to a sequence.
        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs
+
+
# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
class DonutSwinStage(nn.Module):
    """
    One Swin stage: `depth` transformer blocks alternating between regular and shifted windows,
    followed by an optional patch-merging downsample.
    """

    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        # Even-indexed blocks use regular windows (shift 0); odd-indexed blocks use shifted windows.
        blocks = []
        for block_idx in range(depth):
            blocks.append(
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    drop_path_rate=drop_path[block_idx],
                    shift_size=0 if block_idx % 2 == 0 else config.window_size // 2,
                )
            )
        self.blocks = nn.ModuleList(blocks)

        # patch merging layer (halves the spatial resolution between stages)
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Run every block of the stage, then the optional downsample. Returns the (possibly downsampled)
        hidden states, the pre-downsample hidden states, the (h, w, h', w') dimensions, and — when
        `output_attentions` is set — the attentions of the last block.
        """
        height, width = input_dimensions
        for block_idx, block in enumerate(self.blocks):
            layer_head_mask = None if head_mask is None else head_mask[block_idx]

            layer_outputs = block(
                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
            )

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
+class DonutSwinEncoder(nn.Module):
+    """Stack of Swin stages; each stage except the last halves the spatial grid and doubles the channel dim."""
+
+    def __init__(self, config, grid_size):
+        super().__init__()
+        self.num_layers = len(config.depths)
+        self.config = config
+        # Stochastic-depth (drop path) rates increase linearly across all blocks of all stages.
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
+        self.layers = nn.ModuleList(
+            [
+                DonutSwinStage(
+                    config=config,
+                    dim=int(config.embed_dim * 2**i_layer),
+                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+                    depth=config.depths[i_layer],
+                    num_heads=config.num_heads[i_layer],
+                    # Slice of the global drop-path schedule belonging to this stage's blocks.
+                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+                    downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
+                )
+                for i_layer in range(self.num_layers)
+            ]
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, DonutSwinEncoderOutput]:
+        """Run all stages, optionally collecting per-stage hidden states (flat and spatially reshaped) and attentions."""
+        all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            batch_size, _, hidden_size = hidden_states.shape
+            # rearrange b (h w) c -> b c h w
+            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.layers):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+                )
+
+            hidden_states = layer_outputs[0]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
+
+            # The next stage operates on the downsampled spatial dimensions.
+            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states.shape
+                # rearrange b (h w) c -> b c h w
+                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[3:]
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+        return DonutSwinEncoderOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            reshaped_hidden_states=all_reshaped_hidden_states,
+        )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut
+class DonutSwinPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = DonutSwinConfig
+    base_model_prefix = "donut"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DonutSwinStage"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, DonutSwinEmbeddings):
+            # Mask token and (optional) absolute position embeddings start at zero.
+            if module.mask_token is not None:
+                module.mask_token.data.zero_()
+            if module.position_embeddings is not None:
+                module.position_embeddings.data.zero_()
+        elif isinstance(module, DonutSwinSelfAttention):
+            # The learned relative position bias table is also zero-initialized.
+            module.relative_position_bias_table.data.zero_()
+
+
+SWIN_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`DonutSwinConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+ [`DonutImageProcessor.__call__`] for details.
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top.",
+ SWIN_START_DOCSTRING,
+)
+class DonutSwinModel(DonutSwinPreTrainedModel):
+    """Bare Donut Swin encoder: patch embeddings -> Swin stages -> optional average pooling."""
+
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+        super().__init__(config)
+        self.config = config
+        self.num_layers = len(config.depths)
+        # Channel dim of the final stage: embed_dim doubled at every downsampling step.
+        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
+
+        # Averages over the patch/sequence dimension to yield one feature vector per image.
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        # NOTE(review): the encoder defines `layers` (stages containing `blocks`), not `layer`; this
+        # inherited boilerplate would raise AttributeError if invoked — confirm before relying on pruning.
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=DonutSwinModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, DonutSwinModelOutput]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            input_dimensions,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        pooled_output = None
+        if self.pooler is not None:
+            # (batch, seq, dim) -> (batch, dim, seq) -> (batch, dim, 1) -> (batch, dim)
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
+
+        if not return_dict:
+            output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+            return output
+
+        return DonutSwinModelOutput(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+        )
+
+
+@add_start_docstrings(
+ """
+ DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
+ the [CLS] token) e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """,
+ SWIN_START_DOCSTRING,
+)
+# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut
+class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
+    """Donut Swin backbone with a linear classification head on the pooled output."""
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.num_labels = config.num_labels
+        self.donut = DonutSwinModel(config)
+
+        # Classifier head
+        self.classifier = (
+            nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_IMAGE_CLASS_CHECKPOINT,
+        output_type=DonutSwinImageClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, DonutSwinImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.donut(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        # Index 1 is the pooled (sequence-averaged) representation in both tuple and dataclass outputs.
+        pooled_output = outputs[1]
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return DonutSwinImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
+        )
+
+
+__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]
diff --git a/docs/transformers/build/lib/transformers/models/donut/processing_donut.py b/docs/transformers/build/lib/transformers/models/donut/processing_donut.py
new file mode 100644
index 0000000000000000000000000000000000000000..689aa5122f8feca2862f71f02eaad4519d6d7cbf
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/donut/processing_donut.py
@@ -0,0 +1,220 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processor class for Donut.
+"""
+
+import re
+import warnings
+from contextlib import contextmanager
+from typing import List, Optional, Union
+
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging
+
+
+class DonutProcessorKwargs(ProcessingKwargs, total=False):
+    # No Donut-specific defaults; inherits the shared text/images kwargs schema from ProcessingKwargs.
+    _defaults = {}
+
+
+logger = logging.get_logger(__name__)
+
+
+class DonutProcessor(ProcessorMixin):
+ r"""
+ Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
+ processor.
+
+ [`DonutProcessor`] offers all the functionalities of [`DonutImageProcessor`] and
+ [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. See the [`~DonutProcessor.__call__`] and
+ [`~DonutProcessor.decode`] for more information.
+
+ Args:
+ image_processor ([`DonutImageProcessor`], *optional*):
+ An instance of [`DonutImageProcessor`]. The image processor is a required input.
+ tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`], *optional*):
+ An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+ feature_extractor = None
+ if "feature_extractor" in kwargs:
+ warnings.warn(
+ "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
+ " instead.",
+ FutureWarning,
+ )
+ feature_extractor = kwargs.pop("feature_extractor")
+
+ image_processor = image_processor if image_processor is not None else feature_extractor
+ if image_processor is None:
+ raise ValueError("You need to specify an `image_processor`.")
+ if tokenizer is None:
+ raise ValueError("You need to specify a `tokenizer`.")
+
+ super().__init__(image_processor, tokenizer)
+ self.current_processor = self.image_processor
+ self._in_target_context_manager = False
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[DonutProcessorKwargs],
+ ):
+ """
+ When used in normal mode, this method forwards all its arguments to AutoImageProcessor's
+ [`~AutoImageProcessor.__call__`] and returns its output. If used in the context
+ [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's
+ [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
+ """
+ if self._in_target_context_manager:
+ return self.current_processor(images, text, **kwargs)
+
+ if images is None and text is None:
+ raise ValueError("You need to specify either an `images` or `text` input to process.")
+
+ output_kwargs = self._merge_kwargs(
+ DonutProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+ if images is not None:
+ inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
+ if text is not None:
+ if images is not None:
+ output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
+ encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ if text is None:
+ return inputs
+ elif images is None:
+ return encodings
+ else:
+ inputs["labels"] = encodings["input_ids"] # for BC
+ inputs["input_ids"] = encodings["input_ids"]
+ return inputs
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer
+ to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the
+ docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @contextmanager
+ def as_target_processor(self):
+ """
+ Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR.
+ """
+ warnings.warn(
+ "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+ "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+ "your images inputs, or in a separate call."
+ )
+ self._in_target_context_manager = True
+ self.current_processor = self.tokenizer
+ yield
+ self.current_processor = self.image_processor
+ self._in_target_context_manager = False
+
+ def token2json(self, tokens, is_inner_value=False, added_vocab=None):
+ """
+ Convert a (generated) token sequence into an ordered JSON format.
+ """
+ if added_vocab is None:
+ added_vocab = self.tokenizer.get_added_vocab()
+
+ output = {}
+
+ while tokens:
+ start_token = re.search(r"", tokens, re.IGNORECASE)
+ if start_token is None:
+ break
+ key = start_token.group(1)
+ key_escaped = re.escape(key)
+
+ end_token = re.search(rf"", tokens, re.IGNORECASE)
+ start_token = start_token.group()
+ if end_token is None:
+ tokens = tokens.replace(start_token, "")
+ else:
+ end_token = end_token.group()
+ start_token_escaped = re.escape(start_token)
+ end_token_escaped = re.escape(end_token)
+ content = re.search(
+ f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL
+ )
+ if content is not None:
+ content = content.group(1).strip()
+ if r""):
+ leaf = leaf.strip()
+ if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>":
+ leaf = leaf[1:-2] # for categorical special tokens
+ output[key].append(leaf)
+ if len(output[key]) == 1:
+ output[key] = output[key][0]
+
+ tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
+ if tokens[:6] == r"": # non-leaf nodes
+ return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab)
+
+ if len(output):
+ return [output] if is_inner_value else output
+ else:
+ return [] if is_inner_value else {"text_sequence": tokens}
+
+ @property
+ def feature_extractor_class(self):
+ warnings.warn(
+ "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.",
+ FutureWarning,
+ )
+ return self.image_processor_class
+
+ @property
+ def feature_extractor(self):
+ warnings.warn(
+ "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.",
+ FutureWarning,
+ )
+ return self.image_processor
+
+
+__all__ = ["DonutProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/dpr/__init__.py b/docs/transformers/build/lib/transformers/models/dpr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aeadbeaf416575570c280a3e15a52422a007103
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    # Static type checkers see eager imports so all public symbols resolve.
+    from .configuration_dpr import *
+    from .modeling_dpr import *
+    from .modeling_tf_dpr import *
+    from .tokenization_dpr import *
+    from .tokenization_dpr_fast import *
+else:
+    import sys
+
+    # At runtime, replace this module with a lazy proxy that defers submodule
+    # imports until first attribute access (keeps `import transformers` fast).
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e4b97c97a4f7f1acedb0f2e23be4fb5dd770a99
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2010, DPR authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DPR model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class DPRConfig(PretrainedConfig):
+    r"""
+    [`DPRConfig`] is the configuration class to store the configuration of a *DPRModel*.
+
+    This is the configuration class to store the configuration of a [`DPRContextEncoder`], [`DPRQuestionEncoder`], or a
+    [`DPRReader`]. It is used to instantiate the components of the DPR model according to the specified arguments,
+    defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
+    configuration to that of the DPRContextEncoder
+    [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
+    architecture.
+
+    This class is a subclass of [`PretrainedConfig`]. Please check the superclass for the documentation of all kwargs.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the DPR model. Defines the different tokens that can be represented by the *inputs_ids*
+            passed to the forward method of [`BertModel`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        max_position_embeddings (`int`, *optional*, defaults to 512):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the *token_type_ids* passed into [`BertModel`].
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+        projection_dim (`int`, *optional*, defaults to 0):
+            Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
+            projection is done.
+
+    Example:
+
+    ```python
+    >>> from transformers import DPRConfig, DPRContextEncoder
+
+    >>> # Initializing a DPR facebook/dpr-ctx_encoder-single-nq-base style configuration
+    >>> configuration = DPRConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/dpr-ctx_encoder-single-nq-base style configuration
+    >>> model = DPRContextEncoder(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dpr"
+
+    def __init__(
+        self,
+        vocab_size=30522,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        pad_token_id=0,
+        position_embedding_type="absolute",
+        projection_dim: int = 0,
+        **kwargs,
+    ):
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+        # BERT-style encoder hyper-parameters shared by all three DPR components.
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.projection_dim = projection_dim
+        self.position_embedding_type = position_embedding_type
+
+
+__all__ = ["DPRConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5151c0972a7ed72c47d125400b918aba3a0d3c0d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
@@ -0,0 +1,145 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import collections
+from pathlib import Path
+
+import torch
+from torch.serialization import default_restore_location
+
+from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
+
+
# Container mirroring the fields stored in an official DPR training checkpoint.
CheckpointState = collections.namedtuple(
    "CheckpointState", ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"]
)


def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    """Deserialize a DPR checkpoint file into a `CheckpointState` named tuple."""
    print(f"Reading saved model from {model_file}")
    # Remap every storage to CPU so GPU-trained checkpoints load on any machine.
    map_to_cpu = lambda storage, location: default_restore_location(storage, "cpu")
    state_dict = torch.load(model_file, map_location=map_to_cpu, weights_only=True)
    return CheckpointState(**state_dict)
+
+
+class DPRState:
+ def __init__(self, src_file: Path):
+ self.src_file = src_file
+
+ def load_dpr_model(self):
+ raise NotImplementedError
+
+ @staticmethod
+ def from_type(comp_type: str, *args, **kwargs) -> "DPRState":
+ if comp_type.startswith("c"):
+ return DPRContextEncoderState(*args, **kwargs)
+ if comp_type.startswith("q"):
+ return DPRQuestionEncoderState(*args, **kwargs)
+ if comp_type.startswith("r"):
+ return DPRReaderState(*args, **kwargs)
+ else:
+ raise ValueError("Component type must be either 'ctx_encoder', 'question_encoder' or 'reader'.")
+
+
class DPRContextEncoderState(DPRState):
    def load_dpr_model(self):
        """Build a DPRContextEncoder and fill it from the checkpoint's 'ctx_model.' weights."""
        model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0]))
        print(f"Loading DPR biencoder from {self.src_file}")
        saved_state = load_states_from_checkpoint(self.src_file)
        encoder, prefix = model.ctx_encoder, "ctx_model."
        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
        state_dict = {"bert_model.embeddings.position_ids": model.ctx_encoder.bert_model.embeddings.position_ids}
        for name, tensor in saved_state.model_dict.items():
            if not name.startswith(prefix):
                # Weights of the other biencoder half live under a different prefix.
                continue
            stripped = name[len(prefix) :]
            if not stripped.startswith("encode_proj."):
                stripped = "bert_model." + stripped
            state_dict[stripped] = tensor
        encoder.load_state_dict(state_dict)
        return model
+
+
class DPRQuestionEncoderState(DPRState):
    def load_dpr_model(self):
        """Build a DPRQuestionEncoder and fill it from the checkpoint's 'question_model.' weights."""
        model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0]))
        print(f"Loading DPR biencoder from {self.src_file}")
        saved_state = load_states_from_checkpoint(self.src_file)
        encoder, prefix = model.question_encoder, "question_model."
        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
        state_dict = {"bert_model.embeddings.position_ids": model.question_encoder.bert_model.embeddings.position_ids}
        for name, tensor in saved_state.model_dict.items():
            if not name.startswith(prefix):
                # Weights of the other biencoder half live under a different prefix.
                continue
            stripped = name[len(prefix) :]
            if not stripped.startswith("encode_proj."):
                stripped = "bert_model." + stripped
            state_dict[stripped] = tensor
        encoder.load_state_dict(state_dict)
        return model
+
+
class DPRReaderState(DPRState):
    def load_dpr_model(self):
        """Build a DPRReader and fill its span predictor from the checkpoint weights."""
        model = DPRReader(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0]))
        print(f"Loading DPR reader from {self.src_file}")
        saved_state = load_states_from_checkpoint(self.src_file)
        # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3
        state_dict = {
            "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids
        }
        # Unlike the encoder states, every key is copied; only 'encoder.*' keys (except the
        # projection head) are remapped under 'encoder.bert_model.'.
        for name, tensor in saved_state.model_dict.items():
            if name.startswith("encoder.") and not name.startswith("encoder.encode_proj"):
                name = "encoder.bert_model." + name[len("encoder.") :]
            state_dict[name] = tensor
        model.span_predictor.load_state_dict(state_dict)
        return model
+
+
def convert(comp_type: str, src_file: Path, dest_dir: Path):
    """Convert an original DPR checkpoint into a Hugging Face PyTorch model directory.

    Args:
        comp_type: 'ctx_encoder', 'question_encoder' or 'reader' (the first letter suffices).
        src_file: path to the original DPR checkpoint file.
        dest_dir: output directory; created (including missing parents) if needed.
    """
    dest_dir = Path(dest_dir)
    # parents=True: previously the conversion crashed when intermediate directories were missing.
    dest_dir.mkdir(parents=True, exist_ok=True)

    dpr_state = DPRState.from_type(comp_type, src_file=src_file)
    model = dpr_state.load_dpr_model()
    model.save_pretrained(dest_dir)
    # Sanity check: reload what was just saved. `from_pretrained` is a classmethod, so call
    # it on the class — invoking it on the instance worked but obscured that it ignores `model`.
    type(model).from_pretrained(dest_dir)
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--type", type=str, help="Type of the component to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
    )
    parser.add_argument(
        "--src",
        type=str,
        help=(
            "Path to the dpr checkpoint file. They can be downloaded from the official DPR repo"
            " https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the"
            " 'retriever' checkpoints."
        ),
    )
    parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.")
    args = parser.parse_args()

    # Validate inputs explicitly: `assert` statements are stripped under `python -O`,
    # and Path(None) would raise an unhelpful TypeError when --src is omitted.
    if args.type is None:
        parser.error(
            "Please specify the component type of the DPR model to convert: 'ctx_encoder', 'question_encoder' or 'reader'."
        )
    if args.src is None:
        parser.error("--src is required: path to the DPR checkpoint file.")

    src_file = Path(args.src)
    if not src_file.exists():
        raise FileNotFoundError(f"DPR checkpoint not found: {src_file}")

    # Default output directory is derived from the checkpoint name.
    dest_dir = Path(f"converted-{src_file.name}" if args.dest is None else args.dest)
    convert(args.type, src_file, dest_dir)
diff --git a/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff4aa11523e730dee6dc014aef58479f0c2775a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py
@@ -0,0 +1,668 @@
+# coding=utf-8
+# Copyright 2018 DPR Authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DPR model for Open Domain Question Answering."""
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from ...modeling_outputs import BaseModelOutputWithPooling
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..bert.modeling_bert import BertModel
+from .configuration_dpr import DPRConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DPRConfig"
+_CHECKPOINT_FOR_DOC = "facebook/dpr-ctx_encoder-single-nq-base"
+
+
+##########
+# Outputs
+##########
+
+
@dataclass
class DPRContextEncoderOutput(ModelOutput):
    """
    Class for outputs of [`DPRContextEncoder`].

    Args:
        pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
            The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
            hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
            This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # (batch_size, embeddings_size) context embedding used for nearest-neighbor retrieval.
    pooler_output: torch.FloatTensor
    # Per-layer hidden states; populated only when hidden-state output is requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights; populated only when attention output is requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class DPRQuestionEncoderOutput(ModelOutput):
    """
    Class for outputs of [`DPRQuestionEncoder`].

    Args:
        pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
            The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
            hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
            This output is to be used to embed questions for nearest neighbors queries with context embeddings.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # (batch_size, embeddings_size) question embedding used for nearest-neighbor retrieval.
    pooler_output: torch.FloatTensor
    # Per-layer hidden states; populated only when hidden-state output is requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights; populated only when attention output is requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class DPRReaderOutput(ModelOutput):
    """
    Class for outputs of [`DPRReader`].

    Args:
        start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
            Logits of the start index of the span for each passage.
        end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
            Logits of the end index of the span for each passage.
        relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`):
            Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
            question, compared to all the other passages.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # (n_passages, sequence_length) span-start logits per passage.
    start_logits: torch.FloatTensor
    # (n_passages, sequence_length) span-end logits per passage.
    end_logits: Optional[torch.FloatTensor] = None
    # (n_passages,) passage relevance scores from the QA classifier head.
    relevance_logits: Optional[torch.FloatTensor] = None
    # Per-layer hidden states; populated only when hidden-state output is requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Per-layer attention weights; populated only when attention output is requested.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
class DPRPreTrainedModel(PreTrainedModel):
    """Common weight-initialization logic shared by all DPR model classes."""

    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
+
+
class DPREncoder(DPRPreTrainedModel):
    """BERT encoder without a pooling layer, plus an optional linear projection head."""

    base_model_prefix = "bert_model"

    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.bert_model = BertModel(config, add_pooling_layer=False)
        if self.bert_model.config.hidden_size <= 0:
            raise ValueError("Encoder hidden_size can't be zero")
        self.projection_dim = config.projection_dim
        if self.projection_dim > 0:
            # Optional head that maps the pooled output to `projection_dim` dimensions.
            self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
        """Run BERT, take the first ([CLS]) token's hidden state as the pooled output, optionally project it."""
        encoder_outputs = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # The [CLS] position serves as the sequence-level representation.
        pooled_output = sequence_output[:, 0, :]
        if self.projection_dim > 0:
            pooled_output = self.encode_proj(pooled_output)

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (sequence_output, pooled_output) + encoder_outputs[2:]

    @property
    def embeddings_size(self) -> int:
        # Output width: the projection size when a head is present, else BERT's hidden size.
        return self.encode_proj.out_features if self.projection_dim > 0 else self.bert_model.config.hidden_size
+
+
class DPRSpanPredictor(DPRPreTrainedModel):
    """Encodes passages and predicts answer-span start/end logits plus a per-passage relevance score."""

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.encoder = DPREncoder(config)
        # Two logits per token (span start / span end) and one relevance logit per passage.
        self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
        self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
        # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
        if input_ids is not None:
            n_passages, sequence_length = input_ids.size()
        else:
            n_passages, sequence_length = inputs_embeds.size()[:2]
        # feed encoder
        encoder_outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        # compute logits; split the 2-wide output into start/end channels and flatten each
        # back to one logit per token, per passage
        span_logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous().view(n_passages, sequence_length)
        end_logits = end_logits.squeeze(-1).contiguous().view(n_passages, sequence_length)
        # relevance score is computed from the [CLS] representation of each passage
        relevance_logits = self.qa_classifier(sequence_output[:, 0, :]).view(n_passages)

        if return_dict:
            return DPRReaderOutput(
                start_logits=start_logits,
                end_logits=end_logits,
                relevance_logits=relevance_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (start_logits, end_logits, relevance_logits) + encoder_outputs[2:]
+
+
+##################
+# PreTrainedModel
+##################
+
+
class DPRPretrainedContextEncoder(DPRPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPRConfig
    # No TensorFlow checkpoint conversion routine is provided for DPR.
    load_tf_weights = None
    # Checkpoint key prefix under which the context-encoder submodule is stored.
    base_model_prefix = "ctx_encoder"
+
+
class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPRConfig
    # No TensorFlow checkpoint conversion routine is provided for DPR.
    load_tf_weights = None
    # Checkpoint key prefix under which the question-encoder submodule is stored.
    base_model_prefix = "question_encoder"
+
+
class DPRPretrainedReader(DPRPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPRConfig
    # No TensorFlow checkpoint conversion routine is provided for DPR.
    load_tf_weights = None
    # Checkpoint key prefix under which the span-predictor submodule is stored.
    base_model_prefix = "span_predictor"
+
+
+###############
+# Actual Models
+###############
+
+
# Docstring fragment shared by all DPR model classes; injected via `add_start_docstrings`.
DPR_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DPRConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# `forward` docstring for the two biencoder halves; injected via
# `add_start_docstrings_to_model_forward`.
DPR_ENCODERS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
            formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs (for a pair title+text for example):

            ```
            tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            ```

            (b) For single sequences (for a question for example):

            ```
            tokens:         [CLS] the dog is hairy . [SEP]
            token_type_ids:   0   0   0   0  0     0   0
            ```

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# `forward` docstring for DPRReader; injected via `add_start_docstrings_to_model_forward`.
DPR_READER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
            and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should
            be formatted with [CLS] and [SEP] with the format:

            `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `(n_passages, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
    DPR_START_DOCSTRING,
)
class DPRContextEncoder(DPRPretrainedContextEncoder):
    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.config = config
        self.ctx_encoder = DPREncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

        >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```"""

        # Fall back to config-level defaults for any unset output option.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            # Consistency with DPRQuestionEncoder and DPRReader, which both emit this warning;
            # the context encoder previously skipped it.
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            # Default mask: mask out padding tokens when ids are known, otherwise attend everywhere.
            attention_mask = (
                torch.ones(input_shape, device=device)
                if input_ids is None
                else (input_ids != self.config.pad_token_id)
            )
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        outputs = self.ctx_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Drop last_hidden_state so the tuple mirrors DPRContextEncoderOutput's fields.
            return outputs[1:]
        return DPRContextEncoderOutput(
            pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )
+
+
@add_start_docstrings(
    "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
    DPR_START_DOCSTRING,
)
class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.config = config
        self.question_encoder = DPREncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        token_type_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

        >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```
        """
        # Fall back to config-level defaults for any unset output option.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        if attention_mask is None:
            # Mask out padding tokens when ids are available; otherwise attend everywhere.
            if input_ids is not None:
                attention_mask = input_ids != self.config.pad_token_id
            else:
                attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        encoder_outputs = self.question_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return DPRQuestionEncoderOutput(
                pooler_output=encoder_outputs.pooler_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        # Tuple output: drop last_hidden_state so fields match DPRQuestionEncoderOutput.
        return encoder_outputs[1:]
+
+
@add_start_docstrings(
    "The bare DPRReader transformer outputting span predictions.",
    DPR_START_DOCSTRING,
)
class DPRReader(DPRPretrainedReader):
    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.config = config
        self.span_predictor = DPRSpanPredictor(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPR_READER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DPRReaderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import DPRReader, DPRReaderTokenizer

        >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
        >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
        >>> encoded_inputs = tokenizer(
        ...     questions=["What is love ?"],
        ...     titles=["Haddaway"],
        ...     texts=["'What Is Love' is a song recorded by the artist Haddaway"],
        ...     return_tensors="pt",
        ... )
        >>> outputs = model(**encoded_inputs)
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits
        >>> relevance_logits = outputs.relevance_logits
        ```
        """
        # Fall back to config-level defaults for any unset output option.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device

        if attention_mask is None:
            # Without token ids we cannot locate padding, so attend everywhere by default.
            attention_mask = torch.ones(input_shape, device=device)

        # All span/relevance computation is delegated to the span-predictor head.
        return self.span_predictor(
            input_ids,
            attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
# Public symbols of this module, consumed by transformers' lazy module loader.
__all__ = [
    "DPRContextEncoder",
    "DPRPretrainedContextEncoder",
    "DPRPreTrainedModel",
    "DPRPretrainedQuestionEncoder",
    "DPRPretrainedReader",
    "DPRQuestionEncoder",
    "DPRReader",
]
diff --git a/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..303b03ec244d61c8698d8115cf8237cdbe1f960b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py
@@ -0,0 +1,800 @@
+# coding=utf-8
+# Copyright 2018 DPR Authors, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TensorFlow DPR model for Open Domain Question Answering."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import tensorflow as tf
+
+from ...modeling_tf_outputs import TFBaseModelOutputWithPooling
+from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, get_initializer, keras, shape_list, unpack_inputs
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..bert.modeling_tf_bert import TFBertMainLayer
+from .configuration_dpr import DPRConfig
+
+
logger = logging.get_logger(__name__)

# Config class name substituted into generated docstrings by `replace_return_docstrings`.
_CONFIG_FOR_DOC = "DPRConfig"
+
+
+##########
+# Outputs
+##########
+
+
@dataclass
class TFDPRContextEncoderOutput(ModelOutput):
    r"""
    Class for outputs of [`TFDPRContextEncoder`].

    Args:
        pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`):
            The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
            hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
            This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # `X | None` syntax is valid at runtime here thanks to the module-level
    # `from __future__ import annotations`.
    pooler_output: tf.Tensor | None = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
@dataclass
class TFDPRQuestionEncoderOutput(ModelOutput):
    r"""
    Class for outputs of [`TFDPRQuestionEncoder`].

    Args:
        pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`):
            The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
            hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
            This output is to be used to embed questions for nearest neighbors queries with context embeddings.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Annotations harmonized to the `X | None` form used by the sibling outputs.
    pooler_output: tf.Tensor | None = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
@dataclass
class TFDPRReaderOutput(ModelOutput):
    r"""
    Class for outputs of [`TFDPRReader`].

    Args:
        start_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`):
            Logits of the start index of the span for each passage.
        end_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`):
            Logits of the end index of the span for each passage.
        relevance_logits (`tf.Tensor` of shape `(n_passages, )`):
            Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
            question, compared to all the other passages.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    start_logits: tf.Tensor | None = None
    end_logits: tf.Tensor | None = None
    relevance_logits: tf.Tensor | None = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None
+
+
class TFDPREncoderLayer(keras.layers.Layer):
    """BERT-based encoder shared by the DPR context/question encoders and the reader.

    Wraps a [`TFBertMainLayer`] (created without its pooling layer) and pools by
    taking the [CLS] token's hidden state, optionally projected down to
    `config.projection_dim` through a Dense head.
    """

    base_model_prefix = "bert_model"

    def __init__(self, config: DPRConfig, **kwargs):
        super().__init__(**kwargs)

        # resolve name conflict with TFBertMainLayer instead of TFBertModel
        self.bert_model = TFBertMainLayer(config, add_pooling_layer=False, name="bert_model")
        self.config = config

        # A non-positive hidden size would make the encoder degenerate.
        if self.config.hidden_size <= 0:
            raise ValueError("Encoder hidden_size can't be zero")
        self.projection_dim = config.projection_dim
        # The projection head only exists when a positive dim is configured;
        # `build()` below guards on its existence via getattr.
        if self.projection_dim > 0:
            self.encode_proj = keras.layers.Dense(
                config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj"
            )

    @unpack_inputs
    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]:
        """Run BERT and return the sequence output plus the pooled [CLS] representation."""
        outputs = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]
        # DPR pools by taking the hidden state of the first ([CLS]) token.
        pooled_output = sequence_output[:, 0, :]
        if self.projection_dim > 0:
            pooled_output = self.encode_proj(pooled_output)

        if not return_dict:
            return (sequence_output, pooled_output) + outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @property
    def embeddings_size(self) -> int:
        # Width of the pooled embedding: the projection dim when a projection
        # head exists, otherwise BERT's hidden size.
        if self.projection_dim > 0:
            return self.projection_dim
        return self.bert_model.config.hidden_size

    def build(self, input_shape=None):
        """Build sub-layers inside their own name scopes (idempotent)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "bert_model", None) is not None:
            with tf.name_scope(self.bert_model.name):
                self.bert_model.build(None)
        # encode_proj may not exist (projection_dim == 0), hence the getattr guard.
        if getattr(self, "encode_proj", None) is not None:
            with tf.name_scope(self.encode_proj.name):
                self.encode_proj.build(None)
+
+
class TFDPRSpanPredictorLayer(keras.layers.Layer):
    """Reader head: encodes passages and predicts answer spans plus passage relevance.

    On top of a [`TFDPREncoderLayer`], computes per-token start/end logits
    (`qa_outputs`, 2 units) and a per-passage relevance logit from the [CLS]
    representation (`qa_classifier`, 1 unit).
    """

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.encoder = TFDPREncoderLayer(config, name="encoder")

        self.qa_outputs = keras.layers.Dense(
            2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )
        self.qa_classifier = keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier"
        )

    @unpack_inputs
    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        attention_mask: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        training: bool = False,
    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
        """Return start/end span logits and a relevance logit for each passage."""
        # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
        n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2]
        # feed encoder
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]

        # compute logits: split the 2-unit projection into start/end halves,
        # then drop the trailing singleton dimension.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)
        # Relevance is scored from the [CLS] (first token) hidden state.
        relevance_logits = self.qa_classifier(sequence_output[:, 0, :])

        # resize to the documented (n_passages, ...) shapes
        start_logits = tf.reshape(start_logits, [n_passages, sequence_length])
        end_logits = tf.reshape(end_logits, [n_passages, sequence_length])
        relevance_logits = tf.reshape(relevance_logits, [n_passages])

        if not return_dict:
            return (start_logits, end_logits, relevance_logits) + outputs[2:]

        return TFDPRReaderOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            relevance_logits=relevance_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build sub-layers inside their own name scopes (idempotent)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "qa_outputs", None) is not None:
            with tf.name_scope(self.qa_outputs.name):
                self.qa_outputs.build([None, None, self.encoder.embeddings_size])
        if getattr(self, "qa_classifier", None) is not None:
            with tf.name_scope(self.qa_classifier.name):
                self.qa_classifier.build([None, None, self.encoder.embeddings_size])
+
+
class TFDPRSpanPredictor(TFPreTrainedModel):
    """Standalone `TFPreTrainedModel` shell around a [`TFDPRSpanPredictorLayer`].

    `call` simply forwards its inputs to the underlying span-predictor layer.
    """

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.encoder = TFDPRSpanPredictorLayer(config)

    @unpack_inputs
    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        training: bool = False,
    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
        # NOTE(review): `token_type_ids` is accepted for interface compatibility
        # but is not forwarded to the wrapped span-predictor layer.
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
+
+
class TFDPREncoder(TFPreTrainedModel):
    """Standalone `TFPreTrainedModel` shell around a single [`TFDPREncoderLayer`].

    `call` simply forwards its inputs to the underlying encoder layer.
    """

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig, **kwargs):
        super().__init__(config, **kwargs)

        self.encoder = TFDPREncoderLayer(config)

    @unpack_inputs
    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
        training: bool = False,
    ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]:
        # NOTE(review): `token_type_ids` is accepted for interface compatibility
        # but is not forwarded to the wrapped encoder layer.
        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
+
+
+##################
+# PreTrainedModel
+##################
+
+
class TFDPRPretrainedContextEncoder(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class resolved by `from_pretrained` for this model family.
    config_class = DPRConfig
    # Checkpoint weight prefix for the wrapped context-encoder layer.
    base_model_prefix = "ctx_encoder"
+
+
class TFDPRPretrainedQuestionEncoder(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class resolved by `from_pretrained` for this model family.
    config_class = DPRConfig
    # Checkpoint weight prefix for the wrapped question-encoder layer.
    base_model_prefix = "question_encoder"
+
+
class TFDPRPretrainedReader(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Config class resolved by `from_pretrained` for this model family.
    config_class = DPRConfig
    # Checkpoint weight prefix for the wrapped span-predictor layer.
    base_model_prefix = "reader"
+
+
+###############
+# Actual Models
+###############
+
+
# Class-level docstring injected into each model via `add_start_docstrings`;
# documents the Keras input-format conventions shared by all TF DPR models.
TF_DPR_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to
    general usage and behavior.



    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!



    Parameters:
        config ([`DPRConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Forward-pass argument docs injected into the two encoders' `call` via
# `add_start_docstrings_to_model_forward`.
TF_DPR_ENCODERS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
            formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs (for a pair title+text for example):

            ```
            tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
            ```

            (b) For single sequences (for a question for example):

            ```
            tokens: [CLS] the dog is hairy . [SEP]
            token_type_ids: 0 0 0 0 0 0 0
            ```

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
+
# Forward-pass argument docs injected into `TFDPRReader.call` via
# `add_start_docstrings_to_model_forward`. The `[CLS] ... [SEP] ...` format
# line restores the `<question token ids>`/`<titles ids>`/`<texts ids>`
# placeholders that had been stripped (angle-bracket content lost in
# extraction), which made the documented input format meaningless.
TF_DPR_READER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`Numpy array` or `tf.Tensor` of shapes `(n_passages, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
            and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should
            be formatted with [CLS] and [SEP] with the format:

            `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`

            DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
            rather than the left.

            Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.
        attention_mask (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
+
+
@add_start_docstrings(
    "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.",
    TF_DPR_START_DOCSTRING,
)
class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
    def __init__(self, config: DPRConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder")

    def get_input_embeddings(self):
        # The embedding table lives inside the wrapped BERT main layer; if the
        # sub-layer does not exist yet, build the model first and retry.
        try:
            return self.ctx_encoder.bert_model.get_input_embeddings()
        except AttributeError:
            self.build()
            return self.ctx_encoder.bert_model.get_input_embeddings()

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFDPRContextEncoderOutput | Tuple[tf.Tensor, ...]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import TFDPRContextEncoder, DPRContextEncoderTokenizer

        >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        >>> model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", from_pt=True)
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```
        """
        # Exactly one of input_ids / inputs_embeds must be provided; the input
        # shape is derived from whichever is present.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default attention mask: mask out padding tokens when token ids are
        # available, otherwise attend everywhere. Default segment ids: all zero.
        if attention_mask is None:
            attention_mask = (
                tf.ones(input_shape, dtype=tf.dtypes.int32)
                if input_ids is None
                else (input_ids != self.config.pad_token_id)
            )
        if token_type_ids is None:
            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)

        outputs = self.ctx_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Tuple output drops the sequence output (index 0) so the pooled
        # context representation comes first.
        if not return_dict:
            return outputs[1:]

        return TFDPRContextEncoderOutput(
            pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def build(self, input_shape=None):
        """Build the encoder sub-layer inside its own name scope (idempotent)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "ctx_encoder", None) is not None:
            with tf.name_scope(self.ctx_encoder.name):
                self.ctx_encoder.build(None)
+
+
@add_start_docstrings(
    "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.",
    TF_DPR_START_DOCSTRING,
)
class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
    def __init__(self, config: DPRConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.question_encoder = TFDPREncoderLayer(config, name="question_encoder")

    def get_input_embeddings(self):
        # The embedding table lives inside the wrapped BERT main layer; if the
        # sub-layer does not exist yet, build the model first and retry.
        try:
            return self.question_encoder.bert_model.get_input_embeddings()
        except AttributeError:
            self.build()
            return self.question_encoder.bert_model.get_input_embeddings()

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: tf.Tensor | None = None,
        token_type_ids: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFDPRQuestionEncoderOutput | Tuple[tf.Tensor, ...]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import TFDPRQuestionEncoder, DPRQuestionEncoderTokenizer

        >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        >>> model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True)
        >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"]
        >>> embeddings = model(input_ids).pooler_output
        ```
        """
        # Exactly one of input_ids / inputs_embeds must be provided; derive the
        # input shape from whichever is present.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default attention mask: mask out padding when token ids are known,
        # otherwise attend everywhere. Default segment ids: all zero.
        if attention_mask is None:
            if input_ids is None:
                attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)
            else:
                attention_mask = input_ids != self.config.pad_token_id
        if token_type_ids is None:
            token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32)

        encoder_outputs = self.question_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if return_dict:
            return TFDPRQuestionEncoderOutput(
                pooler_output=encoder_outputs.pooler_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        # Tuple output drops the sequence output (index 0) so the pooled
        # question representation comes first.
        return encoder_outputs[1:]

    def build(self, input_shape=None):
        """Build the encoder sub-layer inside its own name scope (idempotent)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "question_encoder", None) is not None:
            with tf.name_scope(self.question_encoder.name):
                self.question_encoder.build(None)
+
+
@add_start_docstrings(
    "The bare DPRReader transformer outputting span predictions.",
    TF_DPR_START_DOCSTRING,
)
class TFDPRReader(TFDPRPretrainedReader):
    def __init__(self, config: DPRConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor")

    def get_input_embeddings(self):
        # The embedding table lives inside the span predictor's BERT main
        # layer; if the sub-layer does not exist yet, build the model and retry.
        try:
            return self.span_predictor.encoder.bert_model.get_input_embeddings()
        except AttributeError:
            self.build()
            return self.span_predictor.encoder.bert_model.get_input_embeddings()

    @unpack_inputs
    @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: tf.Tensor | None = None,
        inputs_embeds: tf.Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
    ) -> TFDPRReaderOutput | Tuple[tf.Tensor, ...]:
        r"""
        Return:

        Examples:

        ```python
        >>> from transformers import TFDPRReader, DPRReaderTokenizer

        >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
        >>> model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True)
        >>> encoded_inputs = tokenizer(
        ...     questions=["What is love ?"],
        ...     titles=["Haddaway"],
        ...     texts=["'What Is Love' is a song recorded by the artist Haddaway"],
        ...     return_tensors="tf",
        ... )
        >>> outputs = model(encoded_inputs)
        >>> start_logits = outputs.start_logits
        >>> end_logits = outputs.end_logits
        >>> relevance_logits = outputs.relevance_logits
        ```
        """
        # Exactly one of input_ids / inputs_embeds must be provided; the input
        # shape is derived from whichever is present.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Unlike the encoders, the reader defaults to attending everywhere
        # (no pad-token masking here).
        if attention_mask is None:
            attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32)

        return self.span_predictor(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def build(self, input_shape=None):
        """Build the span-predictor sub-layer inside its own name scope (idempotent)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "span_predictor", None) is not None:
            with tf.name_scope(self.span_predictor.name):
                self.span_predictor.build(None)
+
+
# Public TensorFlow DPR API exported by this module.
__all__ = [
    "TFDPRContextEncoder",
    "TFDPRPretrainedContextEncoder",
    "TFDPRPretrainedQuestionEncoder",
    "TFDPRPretrainedReader",
    "TFDPRQuestionEncoder",
    "TFDPRReader",
]
diff --git a/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b8dedfa7e4b7d190ad4fa3b9f68597996b038f
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DPR."""
+
+import collections
+from typing import List, Optional, Union
+
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging
+from ..bert.tokenization_bert import BertTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class DPRContextEncoderTokenizer(BertTokenizer):
+ r"""
+ Construct a DPRContextEncoder tokenizer.
+
+ [`DPRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+
+class DPRQuestionEncoderTokenizer(BertTokenizer):
+ r"""
+ Constructs a DPRQuestionEncoder tokenizer.
+
+ [`DPRQuestionEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+
+
+DPRSpanPrediction = collections.namedtuple(
+ "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
+)
+
+DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
+
+
+CUSTOM_DPR_READER_DOCSTRING = r"""
+ Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`.
+ It converts the strings of a question and different passages (title and text) in a sequence of IDs (integers),
+ using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)`
+ with the format:
+
+ ```
+ [CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
+ ```
+
+ Args:
+ questions (`str` or `List[str]`):
+ The questions to be encoded. You can specify one question for many passages. In this case, the question
+ will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in
+ `titles` or `texts`.
+ titles (`str` or `List[str]`):
+ The passages titles to be encoded. This can be a string or a list of strings if there are several passages.
+ texts (`str` or `List[str]`):
+ The passages texts to be encoded. This can be a string or a list of strings if there are several passages.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
+ is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to
+ the maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch
+ of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the first
+ sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the
+ second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*):
+ Whether or not to return the attention mask. If not set, will return the attention mask according to the
+ specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Returns:
+ `Dict[str, List[List[int]]]`: A dictionary with the following keys:
+
+ - `input_ids`: List of token ids to be fed to a model.
+ - `attention_mask`: List of indices specifying which tokens should be attended to by the model.
+ """
+
+
+@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class CustomDPRReaderTokenizerMixin:
+ def __call__(
+ self,
+ questions,
+ titles: Optional[str] = None,
+ texts: Optional[str] = None,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ if titles is None and texts is None:
+ return super().__call__(
+ questions,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ elif titles is None or texts is None:
+ text_pair = titles if texts is None else texts
+ return super().__call__(
+ questions,
+ text_pair,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ titles = titles if not isinstance(titles, str) else [titles]
+ texts = texts if not isinstance(texts, str) else [texts]
+ n_passages = len(titles)
+ questions = questions if not isinstance(questions, str) else [questions] * n_passages
+ if len(titles) != len(texts):
+ raise ValueError(
+ f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts."
+ )
+ encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
+ encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
+ encoded_inputs = {
+ "input_ids": [
+ (encoded_question_and_title + encoded_text)[:max_length]
+ if max_length is not None and truncation
+ else encoded_question_and_title + encoded_text
+ for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
+ ]
+ }
+ if return_attention_mask is not False:
+ attention_mask = []
+ for input_ids in encoded_inputs["input_ids"]:
+ attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
+ encoded_inputs["attention_mask"] = attention_mask
+ return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
+
+ def decode_best_spans(
+ self,
+ reader_input: BatchEncoding,
+ reader_output: DPRReaderOutput,
+ num_spans: int = 16,
+ max_answer_length: int = 64,
+ num_spans_per_passage: int = 4,
+ ) -> List[DPRSpanPrediction]:
+ """
+ Get the span predictions for the extractive Q&A model.
+
+ Returns: *List* of *DPRReaderOutput* sorted by descending *(relevance_score, span_score)*. Each
+ *DPRReaderOutput* is a *Tuple* with:
+
+ - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other
+ spans in the same passage. It corresponds to the sum of the start and end logits of the span.
+ - **relevance_score**: `float` that corresponds to the score of each passage to answer the question,
+ compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
+ - **doc_id**: `int` the id of the passage. - **start_index**: `int` the start index of the span
+ (inclusive). - **end_index**: `int` the end index of the span (inclusive).
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="pt",
+ ... )
+ >>> outputs = model(**encoded_inputs)
+ >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
+ >>> print(predicted_spans[0].text) # best span
+ a song
+ ```"""
+ input_ids = reader_input["input_ids"]
+ start_logits, end_logits, relevance_logits = reader_output[:3]
+ n_passages = len(relevance_logits)
+ sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
+ nbest_spans_predictions: List[DPRReaderOutput] = []
+ for doc_id in sorted_docs:
+ sequence_ids = list(input_ids[doc_id])
+ # assuming question & title information is at the beginning of the sequence
+ passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
+ if sequence_ids[-1] == self.pad_token_id:
+ sequence_len = sequence_ids.index(self.pad_token_id)
+ else:
+ sequence_len = len(sequence_ids)
+
+ best_spans = self._get_best_spans(
+ start_logits=start_logits[doc_id][passage_offset:sequence_len],
+ end_logits=end_logits[doc_id][passage_offset:sequence_len],
+ max_answer_length=max_answer_length,
+ top_spans=num_spans_per_passage,
+ )
+ for start_index, end_index in best_spans:
+ start_index += passage_offset
+ end_index += passage_offset
+ nbest_spans_predictions.append(
+ DPRSpanPrediction(
+ span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
+ relevance_score=relevance_logits[doc_id],
+ doc_id=doc_id,
+ start_index=start_index,
+ end_index=end_index,
+ text=self.decode(sequence_ids[start_index : end_index + 1]),
+ )
+ )
+ if len(nbest_spans_predictions) >= num_spans:
+ break
+ return nbest_spans_predictions[:num_spans]
+
+ def _get_best_spans(
+ self,
+ start_logits: List[int],
+ end_logits: List[int],
+ max_answer_length: int,
+ top_spans: int,
+ ) -> List[DPRSpanPrediction]:
+ """
+ Finds the best answer span for the extractive Q&A model for one passage. It returns the best span by descending
+ `span_score` order and keeping max `top_spans` spans. Spans longer than `max_answer_length` are ignored.
+ """
+ scores = []
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ scores.append(((start_index, start_index + answer_length), start_score + end_score))
+ scores = sorted(scores, key=lambda x: x[1], reverse=True)
+ chosen_span_intervals = []
+ for (start_index, end_index), score in scores:
+ if start_index > end_index:
+ raise ValueError(f"Wrong span indices: [{start_index}:{end_index}]")
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ raise ValueError(f"Span is too long: {length} > {max_answer_length}")
+ if any(
+ start_index <= prev_start_index <= prev_end_index <= end_index
+ or prev_start_index <= start_index <= end_index <= prev_end_index
+ for (prev_start_index, prev_end_index) in chosen_span_intervals
+ ):
+ continue
+ chosen_span_intervals.append((start_index, end_index))
+
+ if len(chosen_span_intervals) == top_spans:
+ break
+ return chosen_span_intervals
+
+
+@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer):
+ r"""
+ Construct a DPRReader tokenizer.
+
+ [`DPRReaderTokenizer`] is almost identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
+ splitting and wordpiece. The difference is that it has three input strings: question, titles and texts that are
+ combined to be fed to the [`DPRReader`] model.
+
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+
+
+__all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e7c0fdcdbf72aebb97216927dea52577b8bc4c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for DPR."""
+
+import collections
+from typing import List, Optional, Union
+
+from ...tokenization_utils_base import BatchEncoding
+from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging
+from ..bert.tokenization_bert_fast import BertTokenizerFast
+from .tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, DPRReaderTokenizer
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
+class DPRContextEncoderTokenizerFast(BertTokenizerFast):
+ r"""
+ Construct a "fast" DPRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = DPRContextEncoderTokenizer
+
+
+class DPRQuestionEncoderTokenizerFast(BertTokenizerFast):
+ r"""
+ Constructs a "fast" DPRQuestionEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRQuestionEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = DPRQuestionEncoderTokenizer
+
+
+DPRSpanPrediction = collections.namedtuple(
+ "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"]
+)
+
+DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"])
+
+
+CUSTOM_DPR_READER_DOCSTRING = r"""
+ Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`.
+ It converts the strings of a question and different passages (title and text) in a sequence of IDs (integers),
+ using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)`
+ with the format:
+
+ [CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>
+
+ Args:
+ questions (`str` or `List[str]`):
+ The questions to be encoded. You can specify one question for many passages. In this case, the question
+ will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in
+ `titles` or `texts`.
+ titles (`str` or `List[str]`):
+ The passages titles to be encoded. This can be a string or a list of strings if there are several passages.
+ texts (`str` or `List[str]`):
+ The passages texts to be encoded. This can be a string or a list of strings if there are several passages.
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+ Activates and controls padding. Accepts the following values:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
+ is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+ Activates and controls truncation. Accepts the following values:
+
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to
+ the maximum acceptable input length for the model if that argument is not provided. This will truncate
+ token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch
+ of pairs) is provided.
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the first
+ sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided. This will only truncate the
+ second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+ greater than the model maximum admissible input size).
+ max_length (`int`, *optional*):
+ Controls the maximum length to use by one of the truncation/padding parameters.
+
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ return_attention_mask (`bool`, *optional*):
+ Whether or not to return the attention mask. If not set, will return the attention mask according to the
+ specific tokenizer's default, defined by the `return_outputs` attribute.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Returns:
+ `Dict[str, List[List[int]]]`: A dictionary with the following keys:
+
+ - `input_ids`: List of token ids to be fed to a model.
+ - `attention_mask`: List of indices specifying which tokens should be attended to by the model.
+ """
+
+
+@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class CustomDPRReaderTokenizerMixin:
+ def __call__(
+ self,
+ questions,
+ titles: Optional[str] = None,
+ texts: Optional[str] = None,
+ padding: Union[bool, str] = False,
+ truncation: Union[bool, str] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ return_attention_mask: Optional[bool] = None,
+ **kwargs,
+ ) -> BatchEncoding:
+ if titles is None and texts is None:
+ return super().__call__(
+ questions,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ elif titles is None or texts is None:
+ text_pair = titles if texts is None else texts
+ return super().__call__(
+ questions,
+ text_pair,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ return_tensors=return_tensors,
+ return_attention_mask=return_attention_mask,
+ **kwargs,
+ )
+ titles = titles if not isinstance(titles, str) else [titles]
+ texts = texts if not isinstance(texts, str) else [texts]
+ n_passages = len(titles)
+ questions = questions if not isinstance(questions, str) else [questions] * n_passages
+ assert len(titles) == len(texts), (
+ f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts."
+ )
+ encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"]
+ encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"]
+ encoded_inputs = {
+ "input_ids": [
+ (encoded_question_and_title + encoded_text)[:max_length]
+ if max_length is not None and truncation
+ else encoded_question_and_title + encoded_text
+ for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts)
+ ]
+ }
+ if return_attention_mask is not False:
+ attention_mask = []
+ for input_ids in encoded_inputs["input_ids"]:
+ attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids])
+ encoded_inputs["attention_mask"] = attention_mask
+ return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors)
+
+ def decode_best_spans(
+ self,
+ reader_input: BatchEncoding,
+ reader_output: DPRReaderOutput,
+ num_spans: int = 16,
+ max_answer_length: int = 64,
+ num_spans_per_passage: int = 4,
+ ) -> List[DPRSpanPrediction]:
+ """
+ Get the span predictions for the extractive Q&A model.
+
+ Returns: *List* of *DPRReaderOutput* sorted by descending *(relevance_score, span_score)*. Each
+ *DPRReaderOutput* is a *Tuple* with:
+
+ - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other
+ spans in the same passage. It corresponds to the sum of the start and end logits of the span.
+ - **relevance_score**: `float` that corresponds to the score of each passage to answer the question,
+ compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader.
+ - **doc_id**: `int` the id of the passage. - **start_index**: `int` the start index of the span
+ (inclusive). - **end_index**: `int` the end index of the span (inclusive).
+
+ Examples:
+
+ ```python
+ >>> from transformers import DPRReader, DPRReaderTokenizer
+
+ >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
+ >>> encoded_inputs = tokenizer(
+ ... questions=["What is love ?"],
+ ... titles=["Haddaway"],
+ ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
+ ... return_tensors="pt",
+ ... )
+ >>> outputs = model(**encoded_inputs)
+ >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs)
+ >>> print(predicted_spans[0].text) # best span
+ a song
+ ```"""
+ input_ids = reader_input["input_ids"]
+ start_logits, end_logits, relevance_logits = reader_output[:3]
+ n_passages = len(relevance_logits)
+ sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__)
+ nbest_spans_predictions: List[DPRReaderOutput] = []
+ for doc_id in sorted_docs:
+ sequence_ids = list(input_ids[doc_id])
+ # assuming question & title information is at the beginning of the sequence
+ passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id
+ if sequence_ids[-1] == self.pad_token_id:
+ sequence_len = sequence_ids.index(self.pad_token_id)
+ else:
+ sequence_len = len(sequence_ids)
+
+ best_spans = self._get_best_spans(
+ start_logits=start_logits[doc_id][passage_offset:sequence_len],
+ end_logits=end_logits[doc_id][passage_offset:sequence_len],
+ max_answer_length=max_answer_length,
+ top_spans=num_spans_per_passage,
+ )
+ for start_index, end_index in best_spans:
+ start_index += passage_offset
+ end_index += passage_offset
+ nbest_spans_predictions.append(
+ DPRSpanPrediction(
+ span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index],
+ relevance_score=relevance_logits[doc_id],
+ doc_id=doc_id,
+ start_index=start_index,
+ end_index=end_index,
+ text=self.decode(sequence_ids[start_index : end_index + 1]),
+ )
+ )
+ if len(nbest_spans_predictions) >= num_spans:
+ break
+ return nbest_spans_predictions[:num_spans]
+
+ def _get_best_spans(
+ self,
+ start_logits: List[int],
+ end_logits: List[int],
+ max_answer_length: int,
+ top_spans: int,
+ ) -> List[DPRSpanPrediction]:
+ """
+ Finds the best answer span for the extractive Q&A model for one passage. It returns the best span by descending
+ `span_score` order and keeping max `top_spans` spans. Spans longer than `max_answer_length` are ignored.
+ """
+ scores = []
+ for start_index, start_score in enumerate(start_logits):
+ for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]):
+ scores.append(((start_index, start_index + answer_length), start_score + end_score))
+ scores = sorted(scores, key=lambda x: x[1], reverse=True)
+ chosen_span_intervals = []
+ for (start_index, end_index), score in scores:
+ assert start_index <= end_index, f"Wrong span indices: [{start_index}:{end_index}]"
+ length = end_index - start_index + 1
+ assert length <= max_answer_length, f"Span is too long: {length} > {max_answer_length}"
+ if any(
+ start_index <= prev_start_index <= prev_end_index <= end_index
+ or prev_start_index <= start_index <= end_index <= prev_end_index
+ for (prev_start_index, prev_end_index) in chosen_span_intervals
+ ):
+ continue
+ chosen_span_intervals.append((start_index, end_index))
+
+ if len(chosen_span_intervals) == top_spans:
+ break
+ return chosen_span_intervals
+
+
+@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING)
+class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast):
+ r"""
+ Constructs a "fast" DPRReader tokenizer (backed by HuggingFace's *tokenizers* library).
+
+ [`DPRReaderTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
+ punctuation splitting and wordpiece. The difference is that it has three input strings: question, titles and texts
+ that are combined to be fed to the [`DPRReader`] model.
+
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
+
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ model_input_names = ["input_ids", "attention_mask"]
+ slow_tokenizer_class = DPRReaderTokenizer
+
+
+__all__ = ["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"]
diff --git a/docs/transformers/build/lib/transformers/models/dpt/__init__.py b/docs/transformers/build/lib/transformers/models/dpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..086750423dbd93ecd6f23bf50adb8e0955c1771d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_dpt import *
+ from .feature_extraction_dpt import *
+ from .image_processing_dpt import *
+ from .modeling_dpt import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f25e18423f142cd9847a3441db723b3e005025
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py
@@ -0,0 +1,300 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""DPT model configuration"""
+
+import copy
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import verify_backbone_config_arguments
+from ..auto.configuration_auto import CONFIG_MAPPING
+from ..bit import BitConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class DPTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`DPTModel`]. It is used to instantiate a DPT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the DPT
+    [Intel/dpt-large](https://huggingface.co/Intel/dpt-large) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 384):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        is_hybrid (`bool`, *optional*, defaults to `False`):
+            Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        backbone_out_indices (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`):
+            Indices of the intermediate hidden states to use from backbone.
+        readout_type (`str`, *optional*, defaults to `"project"`):
+            The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of
+            the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`].
+
+            - "ignore" simply ignores the CLS token.
+            - "add" passes the information from the CLS token to all other tokens by adding the representations.
+            - "project" passes information to the other tokens by concatenating the readout to all other tokens before
+              projecting the
+              representation to the original feature dimension D using a linear layer followed by a GELU non-linearity.
+        reassemble_factors (`List[float]`, *optional*, defaults to `[4, 2, 1, 0.5]`):
+            The up/downsampling factors of the reassemble layers.
+        neck_hidden_sizes (`List[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
+            The hidden sizes to project to for the feature maps of the backbone.
+        fusion_hidden_size (`int`, *optional*, defaults to 256):
+            The number of channels before fusion.
+        head_in_index (`int`, *optional*, defaults to -1):
+            The index of the features to use in the heads.
+        use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`):
+            Whether to use batch normalization in the pre-activate residual units of the fusion blocks.
+        use_bias_in_fusion_residual (`bool`, *optional*, defaults to `None`):
+            Whether to use bias in the pre-activate residual units of the fusion blocks.
+        add_projection (`bool`, *optional*, defaults to `False`):
+            Whether to add a projection layer before the depth estimation head.
+        use_auxiliary_head (`bool`, *optional*, defaults to `True`):
+            Whether to use an auxiliary head during training.
+        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
+            Weight of the cross-entropy loss of the auxiliary head.
+        semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
+            The index that is ignored by the loss function of the semantic segmentation model.
+        semantic_classifier_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the semantic classification head.
+        backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`):
+            Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone.
+        neck_ignore_stages (`List[int]`, *optional*, defaults to `[0, 1]`):
+            Used only for the `hybrid` embedding type. The stages of the readout layers to ignore.
+        backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
+            The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
+            leverage the [`AutoBackbone`] API.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
+        backbone_kwargs (`dict`, *optional*):
+            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+        pooler_output_size (`int`, *optional*):
+            Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+        pooler_act (`str`, *optional*, defaults to `"tanh"`):
+            The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+            Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+            supported for Tensorflow.
+
+    Example:
+
+    ```python
+    >>> from transformers import DPTModel, DPTConfig
+
+    >>> # Initializing a DPT dpt-large style configuration
+    >>> configuration = DPTConfig()
+
+    >>> # Initializing a model from the dpt-large style configuration
+    >>> model = DPTModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "dpt"
+
+    def __init__(
+        self,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        image_size=384,
+        patch_size=16,
+        num_channels=3,
+        is_hybrid=False,
+        qkv_bias=True,
+        backbone_out_indices=[2, 5, 8, 11],  # NOTE: mutable defaults below are treated as read-only
+        readout_type="project",
+        reassemble_factors=[4, 2, 1, 0.5],
+        neck_hidden_sizes=[96, 192, 384, 768],
+        fusion_hidden_size=256,
+        head_in_index=-1,
+        use_batch_norm_in_fusion_residual=False,
+        use_bias_in_fusion_residual=None,
+        add_projection=False,
+        use_auxiliary_head=True,
+        auxiliary_loss_weight=0.4,
+        semantic_loss_ignore_index=255,
+        semantic_classifier_dropout=0.1,
+        backbone_featmap_shape=[1, 1024, 24, 24],
+        neck_ignore_stages=[0, 1],
+        backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
+        use_timm_backbone=False,
+        backbone_kwargs=None,
+        pooler_output_size=None,
+        pooler_act="tanh",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.is_hybrid = is_hybrid
+
+        use_autobackbone = False
+        if self.is_hybrid:
+            if backbone_config is None:  # hybrid mode defaults to a BiT backbone config
+                backbone_config = {
+                    "global_padding": "same",
+                    "layer_type": "bottleneck",
+                    "depths": [3, 4, 9],
+                    "out_features": ["stage1", "stage2", "stage3"],
+                    "embedding_dynamic_padding": True,
+                }
+
+            if isinstance(backbone_config, dict):
+                logger.info("Initializing the config with a `BiT` backbone.")
+                backbone_config = BitConfig(**backbone_config)
+            elif isinstance(backbone_config, PretrainedConfig):
+                backbone_config = backbone_config  # already a config object; kept as-is
+            else:
+                raise ValueError(
+                    f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
+                )
+            self.backbone_config = backbone_config
+            self.backbone_featmap_shape = backbone_featmap_shape
+            self.neck_ignore_stages = neck_ignore_stages
+
+            if readout_type != "project":
+                raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")
+
+        elif backbone is not None or backbone_config is not None:
+            use_autobackbone = True
+            if isinstance(backbone_config, dict):
+                backbone_model_type = backbone_config.get("model_type")
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
+
+            self.backbone_config = backbone_config
+            self.backbone_featmap_shape = None
+            self.neck_ignore_stages = []
+
+            # We only use load_backbone when config.is_hybrid is False
+            verify_backbone_config_arguments(
+                use_timm_backbone=use_timm_backbone,
+                use_pretrained_backbone=use_pretrained_backbone,
+                backbone=backbone,
+                backbone_config=backbone_config,
+                backbone_kwargs=backbone_kwargs,
+            )
+        else:
+            self.backbone_config = None
+            self.backbone_featmap_shape = None
+            self.neck_ignore_stages = []
+
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
+        self.backbone_kwargs = backbone_kwargs
+
+        # ViT parameters used if not using a hybrid backbone
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.use_autobackbone = use_autobackbone
+        self.backbone_out_indices = None if use_autobackbone else backbone_out_indices
+
+        if readout_type not in ["ignore", "add", "project"]:
+            raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']")
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.readout_type = readout_type
+        self.reassemble_factors = reassemble_factors
+        self.neck_hidden_sizes = neck_hidden_sizes
+        self.fusion_hidden_size = fusion_hidden_size
+        self.head_in_index = head_in_index
+        self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual
+        self.use_bias_in_fusion_residual = use_bias_in_fusion_residual
+        self.add_projection = add_projection
+
+        # auxiliary head attributes (semantic segmentation)
+        self.use_auxiliary_head = use_auxiliary_head
+        self.auxiliary_loss_weight = auxiliary_loss_weight
+        self.semantic_loss_ignore_index = semantic_loss_ignore_index
+        self.semantic_classifier_dropout = semantic_classifier_dropout
+        self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+        self.pooler_act = pooler_act
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`]. Returns:
+        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+
+        if output["backbone_config"] is not None:  # nested config must be serialized recursively
+            output["backbone_config"] = self.backbone_config.to_dict()
+
+        output["model_type"] = self.__class__.model_type
+        return output
+
+    @property
+    def sub_configs(self):
+        return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {}  # nested config classes, if any
+
+
+__all__ = ["DPTConfig"]  # public API of this configuration module
diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..367aff7f90e18bab33dc81bfc82148550d73b944
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py
@@ -0,0 +1,383 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DINOv2 + DPT checkpoints from the original repository. URL:
+https://github.com/facebookresearch/dinov2/tree/main"""
+
+import argparse
+import itertools
+import math
+from pathlib import Path
+
+import requests
+import torch
+from PIL import Image
+from torchvision import transforms
+
+from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation, DPTImageProcessor
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
+def get_dpt_config(model_name):  # map a checkpoint name to a DPTConfig with a DINOv2 backbone of matching size
+    if "small" in model_name:
+        # equivalent to stage 3, stage 6, stage 9, stage 12
+        backbone_config = Dinov2Config.from_pretrained(
+            "facebook/dinov2-small", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False
+        )
+        neck_hidden_sizes = [48, 96, 192, 384]
+    elif "base" in model_name:
+        backbone_config = Dinov2Config.from_pretrained(
+            "facebook/dinov2-base", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False
+        )
+        neck_hidden_sizes = [96, 192, 384, 768]
+    elif "large" in model_name:
+        backbone_config = Dinov2Config.from_pretrained(
+            "facebook/dinov2-large", out_indices=[5, 12, 18, 24], apply_layernorm=False, reshape_hidden_states=False
+        )
+        neck_hidden_sizes = [128, 256, 512, 1024]
+    elif "giant" in model_name:
+        backbone_config = Dinov2Config.from_pretrained(
+            "facebook/dinov2-giant", out_indices=[10, 20, 30, 40], apply_layernorm=False, reshape_hidden_states=False
+        )
+        neck_hidden_sizes = [192, 384, 768, 1536]
+    else:
+        raise NotImplementedError("To do")
+
+    config = DPTConfig(
+        backbone_config=backbone_config,
+        neck_hidden_sizes=neck_hidden_sizes,
+        use_bias_in_fusion_residual=False,
+        add_projection=True,
+    )
+
+    return config
+
+
+# here we list all DPT keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys_dpt(config):
+    rename_keys = []
+
+    # fmt: off
+    # activation postprocessing (projections, readout projections + resize blocks)
+    for i in range(4):
+        rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.weight", f"neck.reassemble_stage.layers.{i}.projection.weight"))
+        rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.bias", f"neck.reassemble_stage.layers.{i}.projection.bias"))
+
+        rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.weight", f"neck.reassemble_stage.readout_projects.{i}.0.weight"))
+        rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.bias", f"neck.reassemble_stage.readout_projects.{i}.0.bias"))
+
+        if i != 2:  # index 2 has reassemble factor 1 (see DPTConfig.reassemble_factors), so there is no resize layer to rename
+            rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.weight", f"neck.reassemble_stage.layers.{i}.resize.weight"))
+            rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.bias", f"neck.reassemble_stage.layers.{i}.resize.bias"))
+
+    # fusion layers
+    for i in range(4):
+        rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.weight", f"neck.fusion_stage.layers.{i}.projection.weight"))
+        rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.bias", f"neck.fusion_stage.layers.{i}.projection.bias"))
+        if i != 0:  # fusion block 0 has no first residual unit in the original checkpoint
+            rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution1.weight"))
+            rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution2.weight"))
+        rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution1.weight"))
+        rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution2.weight"))
+
+    # neck convolutions
+    for i in range(4):
+        rename_keys.append((f"decode_head.convs.{i}.conv.weight", f"neck.convs.{i}.weight"))
+
+    # head
+    rename_keys.append(("decode_head.project.conv.weight", "head.projection.weight"))
+    rename_keys.append(("decode_head.project.conv.bias", "head.projection.bias"))
+
+    for i in range(0, 5, 2):
+        rename_keys.append((f"decode_head.conv_depth.head.{i}.weight", f"head.head.{i}.weight"))
+        rename_keys.append((f"decode_head.conv_depth.head.{i}.bias", f"head.head.{i}.bias"))
+    # fmt: on
+
+    return rename_keys
+
+
+# here we list all backbone keys to be renamed (original name on the left, our name on the right)
+def create_rename_keys_backbone(config):
+    rename_keys = []
+
+    # fmt: off
+    # patch embedding layer
+    rename_keys.append(("cls_token", "backbone.embeddings.cls_token"))
+    rename_keys.append(("mask_token", "backbone.embeddings.mask_token"))
+    rename_keys.append(("pos_embed", "backbone.embeddings.position_embeddings"))
+    rename_keys.append(("patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
+    rename_keys.append(("patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"))
+
+    # Transformer encoder
+    for i in range(config.backbone_config.num_hidden_layers):
+        # layernorms
+        rename_keys.append((f"blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.norm1.weight"))
+        rename_keys.append((f"blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.norm1.bias"))
+        rename_keys.append((f"blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.norm2.weight"))
+        rename_keys.append((f"blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.norm2.bias"))
+        # MLP
+        if config.backbone_config.use_swiglu_ffn:
+            rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"backbone.encoder.layer.{i}.mlp.w12.weight"))
+            rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"backbone.encoder.layer.{i}.mlp.w12.bias"))
+            rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"backbone.encoder.layer.{i}.mlp.w3.weight"))
+            rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"backbone.encoder.layer.{i}.mlp.w3.bias"))
+        else:
+            rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.mlp.fc1.weight"))
+            rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.mlp.fc1.bias"))
+            rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.mlp.fc2.weight"))
+            rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.mlp.fc2.bias"))
+        # layerscale
+        rename_keys.append((f"blocks.{i}.ls1.gamma", f"backbone.encoder.layer.{i}.layer_scale1.lambda1"))
+        rename_keys.append((f"blocks.{i}.ls2.gamma", f"backbone.encoder.layer.{i}.layer_scale2.lambda1"))
+        # attention projection layer
+        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight"))
+        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias"))
+    # fmt: on
+
+    rename_keys.append(("norm.weight", "backbone.layernorm.weight"))
+    rename_keys.append(("norm.bias", "backbone.layernorm.bias"))
+
+    return rename_keys
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config):  # splits the fused qkv projection into separate q/k/v entries (mutates state_dict)
+    for i in range(config.backbone_config.num_hidden_layers):
+        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+        hidden_size = config.backbone_config.hidden_size
+        # next, add query, keys and values (in that order) to the state dict
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :]
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[:hidden_size]
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+            hidden_size : hidden_size * 2, :
+        ]
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+            hidden_size : hidden_size * 2
+        ]
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :]
+        state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-hidden_size:]
+
+
+def rename_key(dct, old, new):  # move the value stored under `old` to key `new` (in-place)
+    val = dct.pop(old)
+    dct[new] = val
+
+
+# We will verify our results on the DINOv2 example image
+def prepare_img():
+    url = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
+    im = Image.open(requests.get(url, stream=True).raw)
+    return im
+
+
+name_to_url = {  # checkpoint name -> URL of the original DPT-head weights released with DINOv2
+    "dpt-dinov2-small-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth",
+    "dpt-dinov2-small-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth",
+    "dpt-dinov2-base-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth",
+    "dpt-dinov2-base-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth",
+    "dpt-dinov2-large-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth",
+    "dpt-dinov2-large-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth",
+    "dpt-dinov2-giant-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth",
+    "dpt-dinov2-giant-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth",
+}
+
+
+def get_original_pixel_values(image):  # reproduce the original repository's preprocessing, used for a parity check
+    class CenterPadding:  # pads H and W up to the next multiple of `multiple`, split evenly on both sides
+        def __init__(self, multiple):
+            super().__init__()
+            self.multiple = multiple
+
+        def _get_pad(self, size):
+            new_size = math.ceil(size / self.multiple) * self.multiple
+            pad_size = new_size - size
+            pad_size_left = pad_size // 2
+            pad_size_right = pad_size - pad_size_left
+            return pad_size_left, pad_size_right
+
+        def __call__(self, img):
+            pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in img.shape[-2:][::-1]))  # (left, right, top, bottom) as expected by F.pad
+            output = torch.nn.functional.pad(img, pads)
+            return output
+
+        def __repr__(self):
+            return self.__class__.__name__ + "()"
+
+    def make_depth_transform() -> transforms.Compose:
+        return transforms.Compose(
+            [
+                transforms.ToTensor(),
+                lambda x: 255.0 * x[:3],  # Discard alpha component and scale by 255
+                transforms.Normalize(
+                    mean=(123.675, 116.28, 103.53),
+                    std=(58.395, 57.12, 57.375),
+                ),
+                CenterPadding(multiple=14),
+            ]
+        )
+
+    transform = make_depth_transform()
+    original_pixel_values = transform(image).unsqueeze(0)  # add a batch dimension
+
+    return original_pixel_values
+
+
+@torch.no_grad()
+def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits):
+    """
+    Copy/paste/tweak model's weights to our DPT structure.
+    """
+
+    # define DPT configuration based on URL
+    checkpoint_url = name_to_url[model_name]
+    config = get_dpt_config(model_name)
+
+    # load original DPT state_dict from URL
+    print("URL:", checkpoint_url)
+    dpt_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"]
+    # rename keys
+    rename_keys = create_rename_keys_dpt(config)
+    for src, dest in rename_keys:
+        rename_key(dpt_state_dict, src, dest)
+
+    # load original backbone state_dict from URL
+    if "small" in model_name:
+        original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
+    elif "base" in model_name:
+        original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
+    elif "large" in model_name:
+        original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")
+    elif "giant" in model_name:
+        original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14")
+    else:
+        raise NotImplementedError("To do")
+    original_model.eval()
+    backbone_state_dict = original_model.state_dict()
+
+    # rename keys
+    rename_keys = create_rename_keys_backbone(config)
+    for src, dest in rename_keys:
+        rename_key(backbone_state_dict, src, dest)
+
+    # read in qkv matrices
+    read_in_q_k_v(backbone_state_dict, config)
+
+    for key, val in backbone_state_dict.copy().items():  # iterate a copy: keys are popped/re-inserted below
+        val = backbone_state_dict.pop(key)
+        if "w12" in key:
+            key = key.replace("w12", "weights_in")
+        if "w3" in key:
+            key = key.replace("w3", "weights_out")
+        backbone_state_dict[key] = val
+
+    # merge state_dicts
+    state_dict = {**backbone_state_dict, **dpt_state_dict}
+
+    # load HuggingFace model
+    model = DPTForDepthEstimation(config)
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    print("Missing keys:", missing_keys)
+    print("Unexpected keys:", unexpected_keys)
+    assert missing_keys == [  # fusion layer 0 has no residual_layer1 weights in the original checkpoint
+        "neck.fusion_stage.layers.0.residual_layer1.convolution1.weight",
+        "neck.fusion_stage.layers.0.residual_layer1.convolution2.weight",
+    ]
+    model.eval()
+
+    # Verify image processor
+    processor = DPTImageProcessor(
+        do_resize=False,
+        do_rescale=False,
+        do_pad=True,
+        size_divisor=14,
+        do_normalize=True,
+        image_mean=(123.675, 116.28, 103.53),
+        image_std=(58.395, 57.12, 57.375),
+    )
+
+    image = prepare_img()
+    pixel_values = processor(image, return_tensors="pt").pixel_values.float()
+    original_pixel_values = get_original_pixel_values(image)
+
+    assert torch.allclose(pixel_values, original_pixel_values)
+
+    # Verify forward pass
+    with torch.no_grad():
+        outputs = model(pixel_values)
+
+    predicted_depth = outputs.predicted_depth
+
+    print("Shape of predicted depth:", predicted_depth.shape)
+    print("First values of predicted depth:", predicted_depth[0, :3, :3])
+
+    # assert logits
+    if verify_logits:
+        if model_name == "dpt-dinov2-small-nyu":  # NOTE(review): expected_* are only defined for this name; other names hit a NameError below — confirm intended
+            expected_shape = torch.Size([1, 576, 736])
+            expected_slice = torch.tensor(
+                [[3.3576, 3.4741, 3.4345], [3.4324, 3.5012, 3.2775], [3.2560, 3.3563, 3.2354]]
+            )
+
+        assert predicted_depth.shape == torch.Size(expected_shape)
+        assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-5)
+        print("Looks ok!")
+
+    if pytorch_dump_folder_path is not None:
+        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)  # NOTE(review): parents=True would also allow nested output paths
+        print(f"Saving model and processor to {pytorch_dump_folder_path}")
+        model.save_pretrained(pytorch_dump_folder_path)
+        processor.save_pretrained(pytorch_dump_folder_path)
+
+    if push_to_hub:
+        print("Pushing model and processor to hub...")
+        model.push_to_hub(repo_id=f"facebook/{model_name}")
+        processor.push_to_hub(repo_id=f"facebook/{model_name}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--model_name",
+        default="dpt-dinov2-small-nyu",
+        type=str,
+        choices=name_to_url.keys(),
+        help="Name of the model you'd like to convert.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        help="Path to the output PyTorch model directory.",
+    )
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        help="Whether to push the model to the hub after conversion.",
+    )
+    parser.add_argument(
+        "--verify_logits",
+        action="store_true",
+        required=False,
+        help="Path to the output PyTorch model directory.",  # NOTE(review): help text appears copy-pasted from --pytorch_dump_folder_path; should describe logit verification
+    )
+
+    args = parser.parse_args()
+    convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits)
diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a576d772f577b0f690fb3300c3dd75203f700a6
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py
@@ -0,0 +1,305 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS"""
+
+import argparse
+from pathlib import Path
+
+import requests
+import torch
+from PIL import Image
+
+from transformers import BeitConfig, DPTConfig, DPTForDepthEstimation, DPTImageProcessor
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dpt_config(model_name):
    """Map a MiDaS 3.1 BEiT checkpoint name to a ``(DPTConfig, image_size)`` pair.

    The image size (512 or 384) is parsed from the model name; names containing
    neither raise ``ValueError``.
    """
    if "large" in model_name:
        # beit-large-512 taps encoder layers [5, 11, 17, 23]
        hidden_size = 1024
        num_hidden_layers = 24
        num_attention_heads = 16
        intermediate_size = 4096
        out_features = ["stage6", "stage12", "stage18", "stage24"]
        neck_hidden_sizes = [256, 512, 1024, 1024]
    else:
        # beit-base-384 taps encoder layers [2, 5, 8, 11]
        hidden_size = 768
        num_hidden_layers = 12
        num_attention_heads = 12
        intermediate_size = 3072
        out_features = ["stage3", "stage6", "stage9", "stage12"]
        neck_hidden_sizes = [96, 192, 384, 768]

    if "512" in model_name:
        image_size = 512
    elif "384" in model_name:
        image_size = 384
    else:
        raise ValueError("Model not supported")

    backbone_config = BeitConfig(
        image_size=image_size,
        num_hidden_layers=num_hidden_layers,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_attention_heads=num_attention_heads,
        use_relative_position_bias=True,
        reshape_hidden_states=False,
        out_features=out_features,
    )

    config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes)

    return config, image_size
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
    """Return ``(old_name, new_name)`` pairs mapping original MiDaS checkpoint keys
    to their HuggingFace DPT counterparts."""
    rename_keys = []
    add = rename_keys.append

    # fmt: off
    # stem
    add(("pretrained.model.cls_token", "backbone.embeddings.cls_token"))
    add(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"))
    add(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"))

    # Transformer encoder
    for layer in range(config.backbone_config.num_hidden_layers):
        src = f"pretrained.model.blocks.{layer}"
        dst = f"backbone.encoder.layer.{layer}"
        add((f"{src}.gamma_1", f"{dst}.lambda_1"))
        add((f"{src}.gamma_2", f"{dst}.lambda_2"))
        add((f"{src}.norm1.weight", f"{dst}.layernorm_before.weight"))
        add((f"{src}.norm1.bias", f"{dst}.layernorm_before.bias"))
        add((f"{src}.norm2.weight", f"{dst}.layernorm_after.weight"))
        add((f"{src}.norm2.bias", f"{dst}.layernorm_after.bias"))
        add((f"{src}.mlp.fc1.weight", f"{dst}.intermediate.dense.weight"))
        add((f"{src}.mlp.fc1.bias", f"{dst}.intermediate.dense.bias"))
        add((f"{src}.mlp.fc2.weight", f"{dst}.output.dense.weight"))
        add((f"{src}.mlp.fc2.bias", f"{dst}.output.dense.bias"))
        add((f"{src}.attn.proj.weight", f"{dst}.attention.output.dense.weight"))
        add((f"{src}.attn.proj.bias", f"{dst}.attention.output.dense.bias"))
        add((f"{src}.attn.relative_position_bias_table", f"{dst}.attention.attention.relative_position_bias.relative_position_bias_table"))
        add((f"{src}.attn.relative_position_index", f"{dst}.attention.attention.relative_position_bias.relative_position_index"))

    # activation postprocessing (readout projections + resize blocks)
    for idx in range(4):
        add((f"pretrained.act_postprocess{idx + 1}.0.project.0.weight", f"neck.reassemble_stage.readout_projects.{idx}.0.weight"))
        add((f"pretrained.act_postprocess{idx + 1}.0.project.0.bias", f"neck.reassemble_stage.readout_projects.{idx}.0.bias"))
        add((f"pretrained.act_postprocess{idx + 1}.3.weight", f"neck.reassemble_stage.layers.{idx}.projection.weight"))
        add((f"pretrained.act_postprocess{idx + 1}.3.bias", f"neck.reassemble_stage.layers.{idx}.projection.bias"))
        # the third stage (index 2) has no resize layer
        if idx != 2:
            add((f"pretrained.act_postprocess{idx + 1}.4.weight", f"neck.reassemble_stage.layers.{idx}.resize.weight"))
            add((f"pretrained.act_postprocess{idx + 1}.4.bias", f"neck.reassemble_stage.layers.{idx}.resize.bias"))

    # refinenet: original refinenet4..1 map to fusion layers 0..3 (reversed order)
    for src_idx in range(1, 5):
        dst_idx = 4 - src_idx
        add((f"scratch.refinenet{src_idx}.out_conv.weight", f"neck.fusion_stage.layers.{dst_idx}.projection.weight"))
        add((f"scratch.refinenet{src_idx}.out_conv.bias", f"neck.fusion_stage.layers.{dst_idx}.projection.bias"))
        for unit in (1, 2):
            for conv in (1, 2):
                add((f"scratch.refinenet{src_idx}.resConfUnit{unit}.conv{conv}.weight", f"neck.fusion_stage.layers.{dst_idx}.residual_layer{unit}.convolution{conv}.weight"))
                add((f"scratch.refinenet{src_idx}.resConfUnit{unit}.conv{conv}.bias", f"neck.fusion_stage.layers.{dst_idx}.residual_layer{unit}.convolution{conv}.bias"))

    # scratch convolutions
    for idx in range(4):
        add((f"scratch.layer{idx + 1}_rn.weight", f"neck.convs.{idx}.weight"))

    # head (conv layers 0, 2, 4 of the output head)
    for idx in range(0, 5, 2):
        add((f"scratch.output_conv.{idx}.weight", f"head.head.{idx}.weight"))
        add((f"scratch.output_conv.{idx}.bias", f"head.head.{idx}.bias"))
    # fmt: on

    return rename_keys
+
+
def remove_ignore_keys_(state_dict):
    """Drop the original classification-head weights, which have no DPT counterpart."""
    for key in ("pretrained.model.head.weight", "pretrained.model.head.bias"):
        state_dict.pop(key, None)
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
    """Split each layer's fused qkv projection into separate query/key/value entries.

    The original BEiT checkpoint stores one (3*hidden, hidden) qkv weight matrix and
    separate q/v biases (BEiT keys carry no bias), whereas the HF model expects
    distinct query/key/value tensors.
    """
    dim = config.backbone_config.hidden_size
    for layer in range(config.backbone_config.num_hidden_layers):
        prefix = f"backbone.encoder.layer.{layer}.attention.attention"
        qkv = state_dict.pop(f"pretrained.model.blocks.{layer}.attn.qkv.weight")
        # rows [0, dim) are query, [dim, 2*dim) are key, [2*dim, 3*dim) are value
        state_dict[f"{prefix}.query.weight"] = qkv[:dim, :]
        state_dict[f"{prefix}.query.bias"] = state_dict.pop(f"pretrained.model.blocks.{layer}.attn.q_bias")
        state_dict[f"{prefix}.key.weight"] = qkv[dim : dim * 2, :]
        state_dict[f"{prefix}.value.weight"] = qkv[-dim:, :]
        state_dict[f"{prefix}.value.bias"] = state_dict.pop(f"pretrained.model.blocks.{layer}.attn.v_bias")
+
+
def rename_key(dct, old, new):
    """Move the value stored under *old* to *new* (the old key is removed)."""
    dct[new] = dct.pop(old)
+
+
+# We will verify our results on an image of cute cats
def prepare_img():
    """Download and return the standard COCO cats test image (requires network access)."""
    return Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+
+
@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
    """
    Copy/paste/tweak model's weights to our DPT structure.

    Downloads the original MiDaS 3.1 BEiT checkpoint, renames its keys to the HF DPT
    layout, verifies predicted depth on a test image against hard-coded slices, and
    optionally saves and/or pushes the converted model and processor.
    """

    name_to_url = {
        "dpt-beit-large-512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt",
        "dpt-beit-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt",
        "dpt-beit-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt",
    }

    # define DPT configuration based on URL
    checkpoint_url = name_to_url[model_name]
    config, image_size = get_dpt_config(model_name)
    # load original state_dict from URL
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # remove certain keys
    remove_ignore_keys_(state_dict)
    # rename keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    # read in qkv matrices
    read_in_q_k_v(state_dict, config)

    # load HuggingFace model
    model = DPTForDepthEstimation(config)
    # strict=False: the original checkpoint carries extra keys (e.g. fc_norm) we ignore
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)
    assert missing_keys == []
    # assert unexpected_keys == ["pretrained.model.fc_norm.weight", "pretrained.model.fc_norm.bias"]
    model.eval()

    # Check outputs on an image
    # We set `keep_aspect_ratio=False` as our current BEiT does not support arbitrary window sizes
    processor = DPTImageProcessor(
        size={"height": image_size, "width": image_size}, keep_aspect_ratio=False, ensure_multiple_of=32
    )

    image = prepare_img()
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # printed for debugging only: these pixel values are NOT the ones fed to the model below
    print("First values of pixel values:", pixel_values[0, 0, :3, :3])
    print("Mean of pixel values:", pixel_values.mean().item())
    print("Shape of pixel values:", pixel_values.shape)

    # NOTE(review): the processor output above is overwritten below by a plain
    # torchvision resize + ToTensor pipeline; the expected slices further down were
    # presumably computed with this pipeline — confirm before unifying on the processor.
    import requests
    from PIL import Image
    from torchvision import transforms

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    # this rebinding shadows the torchvision `transforms` module with the composed pipeline
    transforms = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )
    pixel_values = transforms(image).unsqueeze(0)

    # forward pass
    with torch.no_grad():
        outputs = model(pixel_values)

    predicted_depth = outputs.predicted_depth

    print("Shape of predicted depth:", predicted_depth.shape)
    print("First values of predicted depth:", predicted_depth[0, :3, :3])

    # assert logits
    # TODO there's still a small difference with the original logits
    if model_name == "dpt-beit-large-512":
        # OK, checked
        expected_shape = torch.Size([1, 512, 512])
        expected_slice = torch.tensor(
            [[2804.6260, 2792.5708, 2812.9263], [2772.0288, 2780.1118, 2796.2529], [2748.1094, 2766.6558, 2766.9834]]
        )
    elif model_name == "dpt-beit-large-384":
        # OK, checked
        expected_shape = torch.Size([1, 384, 384])
        expected_slice = torch.tensor(
            [[1783.2273, 1780.5729, 1792.6453], [1759.9817, 1765.5359, 1778.5002], [1739.1633, 1754.7903, 1757.1990]],
        )
    elif model_name == "dpt-beit-base-384":
        # OK, checked
        expected_shape = torch.Size([1, 384, 384])
        expected_slice = torch.tensor(
            [[2898.4482, 2891.3750, 2904.8079], [2858.6685, 2877.2615, 2894.4507], [2842.1235, 2854.1023, 2861.6328]],
        )

    # torch.Size(expected_shape) is a no-op copy: expected_shape is already a torch.Size
    assert predicted_depth.shape == torch.Size(expected_shape)
    assert torch.allclose(predicted_depth[0, :3, :3], expected_slice)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and processor to hub...")
        model.push_to_hub(repo_id=f"nielsr/{model_name}")
        processor.push_to_hub(repo_id=f"nielsr/{model_name}")
+
+
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the conversion.
    supported_models = ["dpt-beit-large-512", "dpt-beit-large-384", "dpt-beit-base-384"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name",
        default="dpt-beit-large-512",
        type=str,
        choices=supported_models,
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub after conversion.",
    )

    args = parser.parse_args()
    convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceae9b84711463261f20e70b7cd174437a1666cd
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py
@@ -0,0 +1,315 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DPT checkpoints from the original repository. URL: https://github.com/isl-org/DPT"""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import DPTConfig, DPTForDepthEstimation, DPTForSemanticSegmentation, DPTImageProcessor
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dpt_config(checkpoint_url):
    """Derive a hybrid-embedding DPTConfig (and expected output shape) from the checkpoint URL.

    The branches below are applied in order and are NOT mutually exclusive, so a URL
    matching several substrings accumulates (and may override) earlier settings.
    """
    config = DPTConfig(embedding_type="hybrid")

    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        config.backbone_out_indices = [5, 11, 17, 23]
        config.neck_hidden_sizes = [256, 512, 1024, 1024]
        expected_shape = (1, 384, 384)

    if "nyu" in checkpoint_url or "midas" in checkpoint_url:
        config.hidden_size = 768
        config.reassemble_factors = [1, 1, 1, 0.5]
        config.neck_hidden_sizes = [256, 512, 768, 768]
        config.num_labels = 150
        config.patch_size = 16
        expected_shape = (1, 384, 384)
        config.use_batch_norm_in_fusion_residual = False
        config.readout_type = "project"

    if "ade" in checkpoint_url:
        config.use_batch_norm_in_fusion_residual = True
        config.hidden_size = 768
        # NOTE(review): the branch above sets `reassemble_factors`; this one sets
        # `reassemble_stage` with the same values — one of the two attribute names is
        # presumably a typo; confirm against DPTConfig.
        config.reassemble_stage = [1, 1, 1, 0.5]
        config.num_labels = 150
        config.patch_size = 16
        repo_id = "huggingface/label-files"
        filename = "ade20k-id2label.json"
        # fetch the ADE20K id -> label mapping from the hub for semantic segmentation
        id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        # segmentation output shape: (batch, num_labels, height, width)
        expected_shape = [1, 150, 480, 480]

    # NOTE(review): a URL matching none of the branches leaves `expected_shape`
    # undefined and raises NameError here — consider an explicit error message.
    return config, expected_shape
+
+
def remove_ignore_keys_(state_dict):
    """Remove the original classifier-head entries; the DPT model has no use for them."""
    state_dict.pop("pretrained.model.head.weight", None)
    state_dict.pop("pretrained.model.head.bias", None)
+
+
def rename_key(name):
    """Translate one original DPT-hybrid checkpoint key into the HF naming scheme.

    The substring replacements below are strictly order-dependent (earlier rewrites
    feed later matches, e.g. "scratch" -> "neck" before the refinenet handling, and
    the ".." cleanup after the patch_embed removal); do not reorder them.
    """
    # encoder blocks vs. embedding tensors: both live under "pretrained.model"
    if (
        "pretrained.model" in name
        and "cls_token" not in name
        and "pos_embed" not in name
        and "patch_embed" not in name
    ):
        name = name.replace("pretrained.model", "dpt.encoder")
    if "pretrained.model" in name:
        name = name.replace("pretrained.model", "dpt.embeddings")
    if "patch_embed" in name:
        # leaves a ".." in the name; cleaned up by the ".." replacement below
        name = name.replace("patch_embed", "")
    if "pos_embed" in name:
        name = name.replace("pos_embed", "position_embeddings")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "proj" in name and "project" not in name:
        name = name.replace("proj", "projection")
    if "blocks" in name:
        name = name.replace("blocks", "layer")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")
    if "norm1" in name and "backbone" not in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name and "backbone" not in name:
        name = name.replace("norm2", "layernorm_after")
    if "scratch.output_conv" in name:
        name = name.replace("scratch.output_conv", "head")
    if "scratch" in name:
        name = name.replace("scratch", "neck")
    if "layer1_rn" in name:
        name = name.replace("layer1_rn", "convs.0")
    if "layer2_rn" in name:
        name = name.replace("layer2_rn", "convs.1")
    if "layer3_rn" in name:
        name = name.replace("layer3_rn", "convs.2")
    if "layer4_rn" in name:
        name = name.replace("layer4_rn", "convs.3")
    if "refinenet" in name:
        layer_idx = int(name[len("neck.refinenet") : len("neck.refinenet") + 1])
        # tricky here: we need to map 4 to 0, 3 to 1, 2 to 2 and 1 to 3
        name = name.replace(f"refinenet{layer_idx}", f"fusion_stage.layers.{abs(layer_idx - 4)}")
    if "out_conv" in name:
        name = name.replace("out_conv", "projection")
    if "resConfUnit1" in name:
        name = name.replace("resConfUnit1", "residual_layer1")
    if "resConfUnit2" in name:
        name = name.replace("resConfUnit2", "residual_layer2")
    if "conv1" in name:
        name = name.replace("conv1", "convolution1")
    if "conv2" in name:
        name = name.replace("conv2", "convolution2")
    # readout blocks
    if "pretrained.act_postprocess1.0.project.0" in name:
        name = name.replace("pretrained.act_postprocess1.0.project.0", "neck.reassemble_stage.readout_projects.0.0")
    if "pretrained.act_postprocess2.0.project.0" in name:
        name = name.replace("pretrained.act_postprocess2.0.project.0", "neck.reassemble_stage.readout_projects.1.0")
    if "pretrained.act_postprocess3.0.project.0" in name:
        name = name.replace("pretrained.act_postprocess3.0.project.0", "neck.reassemble_stage.readout_projects.2.0")
    if "pretrained.act_postprocess4.0.project.0" in name:
        name = name.replace("pretrained.act_postprocess4.0.project.0", "neck.reassemble_stage.readout_projects.3.0")

    # resize blocks (act_postprocess3.4 has no mapping: stage 3 carries no resize layer)
    if "pretrained.act_postprocess1.3" in name:
        name = name.replace("pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection")
    if "pretrained.act_postprocess1.4" in name:
        name = name.replace("pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize")
    if "pretrained.act_postprocess2.3" in name:
        name = name.replace("pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection")
    if "pretrained.act_postprocess2.4" in name:
        name = name.replace("pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize")
    if "pretrained.act_postprocess3.3" in name:
        name = name.replace("pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection")
    if "pretrained.act_postprocess4.3" in name:
        name = name.replace("pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection")
    if "pretrained.act_postprocess4.4" in name:
        name = name.replace("pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize")
    if "pretrained" in name:
        name = name.replace("pretrained", "dpt")
    if "bn" in name:
        name = name.replace("bn", "batch_norm")
    if "head" in name:
        name = name.replace("head", "head.head")
    if "encoder.norm" in name:
        name = name.replace("encoder.norm", "layernorm")
    if "auxlayer" in name:
        name = name.replace("auxlayer", "auxiliary_head.head")
    if "backbone" in name:
        name = name.replace("backbone", "backbone.bit.encoder")

    if ".." in name:
        name = name.replace("..", ".")

    # hybrid (BiT) backbone specifics
    if "stem.conv" in name:
        name = name.replace("stem.conv", "bit.embedder.convolution")
    if "blocks" in name:
        name = name.replace("blocks", "layers")
    if "convolution" in name and "backbone" in name:
        name = name.replace("convolution", "conv")
    if "layer" in name and "backbone" in name:
        name = name.replace("layer", "layers")
    if "backbone.bit.encoder.bit" in name:
        name = name.replace("backbone.bit.encoder.bit", "backbone.bit")
    if "embedder.conv" in name:
        name = name.replace("embedder.conv", "embedder.convolution")
    if "backbone.bit.encoder.stem.norm" in name:
        name = name.replace("backbone.bit.encoder.stem.norm", "backbone.bit.embedder.norm")
    return name
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config):
    """Split each ViT layer's fused qkv projection into separate query/key/value tensors.

    In timm the input projection is stored as a single (3*hidden, hidden) matrix plus
    one bias vector; the HF model expects them as three separate linear layers.
    """
    dim = config.hidden_size
    for layer in range(config.num_hidden_layers):
        prefix = f"dpt.encoder.layer.{layer}.attention.attention"
        qkv_weight = state_dict.pop(f"dpt.encoder.layer.{layer}.attn.qkv.weight")
        qkv_bias = state_dict.pop(f"dpt.encoder.layer.{layer}.attn.qkv.bias")
        # rows [0, dim) are query, [dim, 2*dim) are key, [2*dim, 3*dim) are value
        state_dict[f"{prefix}.query.weight"] = qkv_weight[:dim, :]
        state_dict[f"{prefix}.query.bias"] = qkv_bias[:dim]
        state_dict[f"{prefix}.key.weight"] = qkv_weight[dim : dim * 2, :]
        state_dict[f"{prefix}.key.bias"] = qkv_bias[dim : dim * 2]
        state_dict[f"{prefix}.value.weight"] = qkv_weight[-dim:, :]
        state_dict[f"{prefix}.value.bias"] = qkv_bias[-dim:]
+
+
+# We will verify our results on an image of cute cats
def prepare_img():
    """Fetch the standard COCO "two cats" validation image used for sanity checks."""
    response = requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True)
    return Image.open(response.raw)
+
+
@torch.no_grad()
def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub, model_name, show_prediction):
    """
    Copy/paste/tweak model's weights to our DPT structure.

    Loads an original DPT-hybrid checkpoint from a local path, renames its keys to
    the HF layout, optionally displays the prediction on a test image, and saves
    and/or pushes the converted model and image processor.
    """

    # define DPT configuration based on URL
    config, expected_shape = get_dpt_config(checkpoint_url)
    # load original state_dict from URL
    # state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # NOTE(review): despite the parameter name, this treats `checkpoint_url` as a local
    # file path (torch.load); `expected_shape` from get_dpt_config is unused here.
    state_dict = torch.load(checkpoint_url, map_location="cpu", weights_only=True)
    # remove certain keys
    remove_ignore_keys_(state_dict)
    # rename keys
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val
    # read in qkv matrices
    read_in_q_k_v(state_dict, config)

    # load HuggingFace model: "ade" checkpoints are segmentation heads, the rest depth
    model = DPTForSemanticSegmentation(config) if "ade" in checkpoint_url else DPTForDepthEstimation(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image
    size = 480 if "ade" in checkpoint_url else 384
    image_processor = DPTImageProcessor(size=size)

    image = prepare_img()
    encoding = image_processor(image, return_tensors="pt")

    # forward pass
    outputs = model(**encoding).logits if "ade" in checkpoint_url else model(**encoding).predicted_depth

    if show_prediction:
        # upsample to the original image resolution for display
        prediction = (
            torch.nn.functional.interpolate(
                outputs.unsqueeze(1),
                size=(image.size[1], image.size[0]),
                mode="bicubic",
                align_corners=False,
            )
            .squeeze()
            .cpu()
            .numpy()
        )

        # NOTE(review): the array is float here, so PIL builds a 32-bit float ("F" mode)
        # image; an explicit cast to uint8 may be intended — confirm.
        Image.fromarray((prediction / prediction.max()) * 255).show()

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving image processor to {pytorch_dump_folder_path}")
        image_processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # NOTE(review): `model_name` is unused — the hub repo id is hard-coded; verify intent.
        model.push_to_hub("ybelkada/dpt-hybrid-midas")
        image_processor.push_to_hub("ybelkada/dpt-hybrid-midas")
+
+
if __name__ == "__main__":
    # CLI entry point for converting an original DPT-hybrid checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint_url",
        default="https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        type=str,
        help="URL of the original DPT checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=False,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument(
        "--model_name",
        default="dpt-large",
        type=str,
        help="Name of the model, in case you're pushing to the hub.",
    )
    parser.add_argument("--show_prediction", action="store_true")

    cli_args = parser.parse_args()
    convert_dpt_checkpoint(
        cli_args.checkpoint_url,
        cli_args.pytorch_dump_folder_path,
        cli_args.push_to_hub,
        cli_args.model_name,
        cli_args.show_prediction,
    )
diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..0feebe72d47419b1bce08a25f85fda81c3822210
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py
@@ -0,0 +1,321 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS"""
+
+import argparse
+from pathlib import Path
+
+import requests
+import torch
+from PIL import Image
+
+from transformers import DPTConfig, DPTForDepthEstimation, DPTImageProcessor, Swinv2Config
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dpt_config(model_name):
    """Build the HF `DPTConfig` (Swinv2 backbone) for a MiDaS 3.1 model name.

    Args:
        model_name (`str`): one of "dpt-swinv2-tiny-256", "dpt-swinv2-base-384",
            "dpt-swinv2-large-384" (size and resolution are parsed from the name).

    Returns:
        `tuple`: `(config, image_size)` — the `DPTConfig` and the square input size.

    Raises:
        ValueError: if the name does not encode a supported size, resolution or
            neck configuration. Previously an unrecognized name fell through the
            if/elif chains and crashed later with an `UnboundLocalError`; now we
            fail fast with a clear message.
    """
    if "tiny" in model_name:
        embed_dim = 96
        depths = (2, 2, 6, 2)
        num_heads = (3, 6, 12, 24)
        window_size = 16
        # note: for Swinv2-tiny authors used the window_size = 16 variant
        # as seen here: https://github.com/isl-org/MiDaS/blob/bdc4ed64c095e026dc0a2f17cabb14d58263decb/midas/backbones/swin2.py#L26
        pretrained_window_sizes = (0, 0, 0, 0)
    elif "base" in model_name:
        embed_dim = 128
        depths = (2, 2, 18, 2)
        num_heads = (4, 8, 16, 32)
        window_size = 24
        pretrained_window_sizes = (12, 12, 12, 6)
    elif "large" in model_name:
        embed_dim = 192
        depths = (2, 2, 18, 2)
        num_heads = (6, 12, 24, 48)
        window_size = 24
        pretrained_window_sizes = (12, 12, 12, 6)
    else:
        raise ValueError(f"Model size not supported, should be one of tiny/base/large: {model_name}")

    if "384" in model_name:
        image_size = 384
    elif "256" in model_name:
        image_size = 256
    else:
        raise ValueError(f"Image resolution not supported, should be 256 or 384: {model_name}")

    backbone_config = Swinv2Config(
        image_size=image_size,
        embed_dim=embed_dim,
        depths=depths,
        window_size=window_size,
        pretrained_window_sizes=pretrained_window_sizes,
        num_heads=num_heads,
        out_features=["stage1", "stage2", "stage3", "stage4"],
    )

    # neck channel sizes depend on the backbone's per-stage hidden sizes
    if model_name == "dpt-swinv2-tiny-256":
        neck_hidden_sizes = [96, 192, 384, 768]
    elif model_name == "dpt-swinv2-base-384":
        neck_hidden_sizes = [128, 256, 512, 1024]
    elif model_name == "dpt-swinv2-large-384":
        neck_hidden_sizes = [192, 384, 768, 1536]
    else:
        raise ValueError(f"Model not supported: {model_name}")

    config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes)

    return config, image_size
+
+
def create_rename_keys(config):
    """Return ordered (original, converted) pairs mapping MiDaS parameter names
    to their HF equivalents, based on the backbone's stage depths."""
    pairs = []

    # fmt: off
    # patch embedding stem
    pairs.extend([
        ("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight"),
        ("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias"),
        ("pretrained.model.patch_embed.norm.weight", "backbone.embeddings.norm.weight"),
        ("pretrained.model.patch_embed.norm.bias", "backbone.embeddings.norm.bias"),
    ])

    # transformer encoder: per-block parameter suffixes (original -> HF)
    block_suffixes = [
        ("attn.logit_scale", "attention.self.logit_scale"),
        ("attn.cpb_mlp.0.weight", "attention.self.continuous_position_bias_mlp.0.weight"),
        ("attn.cpb_mlp.0.bias", "attention.self.continuous_position_bias_mlp.0.bias"),
        ("attn.cpb_mlp.2.weight", "attention.self.continuous_position_bias_mlp.2.weight"),
        ("attn.q_bias", "attention.self.query.bias"),
        ("attn.v_bias", "attention.self.value.bias"),
        ("attn.proj.weight", "attention.output.dense.weight"),
        ("attn.proj.bias", "attention.output.dense.bias"),
        ("norm1.weight", "layernorm_before.weight"),
        ("norm1.bias", "layernorm_before.bias"),
        ("mlp.fc1.weight", "intermediate.dense.weight"),
        ("mlp.fc1.bias", "intermediate.dense.bias"),
        ("mlp.fc2.weight", "output.dense.weight"),
        ("mlp.fc2.bias", "output.dense.bias"),
        ("norm2.weight", "layernorm_after.weight"),
        ("norm2.bias", "layernorm_after.bias"),
    ]
    for stage_idx, depth in enumerate(config.backbone_config.depths):
        for block_idx in range(depth):
            src_prefix = f"pretrained.model.layers.{stage_idx}.blocks.{block_idx}"
            dst_prefix = f"backbone.encoder.layers.{stage_idx}.blocks.{block_idx}"
            pairs.extend((f"{src_prefix}.{old}", f"{dst_prefix}.{new}") for old, new in block_suffixes)

        # downsample (patch merging) parameters exist for stages 0-2 only
        if stage_idx in (0, 1, 2):
            for suffix in ("reduction.weight", "norm.weight", "norm.bias"):
                pairs.append((
                    f"pretrained.model.layers.{stage_idx}.downsample.{suffix}",
                    f"backbone.encoder.layers.{stage_idx}.downsample.{suffix}",
                ))

    # note: non-Transformer backbones like Swinv2, LeViT et al don't require activation postprocessing (readout projections + resize blocks)

    # refinenet: indices are reversed between layouts (refinenet1..4 -> fusion layers 3..0)
    for refinenet_idx in range(1, 5):
        fusion_prefix = f"neck.fusion_stage.layers.{4 - refinenet_idx}"
        src = f"scratch.refinenet{refinenet_idx}"
        pairs.append((f"{src}.out_conv.weight", f"{fusion_prefix}.projection.weight"))
        pairs.append((f"{src}.out_conv.bias", f"{fusion_prefix}.projection.bias"))
        for unit in (1, 2):
            for conv in (1, 2):
                for param in ("weight", "bias"):
                    pairs.append((
                        f"{src}.resConfUnit{unit}.conv{conv}.{param}",
                        f"{fusion_prefix}.residual_layer{unit}.convolution{conv}.{param}",
                    ))

    # scratch convolutions feeding the fusion stage
    for conv_idx in range(4):
        pairs.append((f"scratch.layer{conv_idx + 1}_rn.weight", f"neck.convs.{conv_idx}.weight"))

    # depth head: conv layers sit at indices 0, 2 and 4 of the sequential head
    for head_idx in (0, 2, 4):
        pairs.append((f"scratch.output_conv.{head_idx}.weight", f"head.head.{head_idx}.weight"))
        pairs.append((f"scratch.output_conv.{head_idx}.bias", f"head.head.{head_idx}.bias"))
    # fmt: on

    return pairs
+
+
def remove_ignore_keys_(state_dict):
    """Drop the original classification-head weights, which have no HF counterpart (in place)."""
    for key in ("pretrained.model.head.weight", "pretrained.model.head.bias"):
        state_dict.pop(key, None)
+
+
def read_in_q_k_v(state_dict, config, model):
    """Split each fused qkv projection weight into separate query/key/value weights.

    The original checkpoint stores one (3*dim, dim) matrix per attention block,
    laid out as [query; key; value] along dim 0; the HF model expects three
    separate weights. Mutates `state_dict` in place.
    """
    for stage_idx, depth in enumerate(config.backbone_config.depths):
        for block_idx in range(depth):
            attention = model.backbone.encoder.layers[stage_idx].blocks[block_idx].attention
            head_size = attention.self.all_head_size
            fused = state_dict.pop(f"pretrained.model.layers.{stage_idx}.blocks.{block_idx}.attn.qkv.weight")
            target = f"backbone.encoder.layers.{stage_idx}.blocks.{block_idx}.attention.self"
            state_dict[f"{target}.query.weight"] = fused[:head_size, :]
            state_dict[f"{target}.key.weight"] = fused[head_size : head_size * 2, :]
            state_dict[f"{target}.value.weight"] = fused[-head_size:, :]
+
+
def rename_key(dct, old, new):
    """Move the value stored under key `old` to key `new` (mutates `dct` in place)."""
    dct[new] = dct.pop(old)
+
+
# We will verify our results on an image of cute cats
def prepare_img():
    """Download and return the standard COCO cats test image as a `PIL.Image`."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
+
+
@torch.no_grad()
def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, verify_logits, push_to_hub):
    """
    Copy/paste/tweak model's weights to our DPT structure.

    Downloads the original MiDaS 3.1 Swinv2 checkpoint for `model_name`, renames
    and splits its parameters into the HF layout, optionally verifies the depth
    output against hard-coded reference slices, and optionally saves the result
    locally and/or pushes it to the hub under `Intel/{model_name}`.
    """

    # mapping from supported model names to original MiDaS 3.1 checkpoint URLs
    name_to_url = {
        "dpt-swinv2-tiny-256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt",
        "dpt-swinv2-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt",
        "dpt-swinv2-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt",
    }

    # define DPT configuration based on URL
    checkpoint_url = name_to_url[model_name]
    config, image_size = get_dpt_config(model_name)
    # load original state_dict from URL
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")

    # load HuggingFace model
    model = DPTForDepthEstimation(config)

    # remove certain keys
    remove_ignore_keys_(state_dict)
    # rename keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    # read in qkv matrices
    read_in_q_k_v(state_dict, config, model)

    # strict=False: any keys still missing/unexpected are printed for inspection
    # instead of raising
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing_keys)
    print("Unexpected keys:", unexpected_keys)
    model.eval()

    # Check outputs on an image
    processor = DPTImageProcessor(size={"height": image_size, "width": image_size})

    image = prepare_img()
    # NOTE(review): the processor output is discarded here — this call only
    # exercises the processor; the verification below re-preprocesses the
    # image with torchvision instead. Confirm this is intentional.
    processor(image, return_tensors="pt")

    if verify_logits:
        from torchvision import transforms

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)

        # NOTE(review): this rebinding shadows the `transforms` module imported
        # just above; works because the module is not used again afterwards
        transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
            ]
        )
        pixel_values = transforms(image).unsqueeze(0)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values)

        predicted_depth = outputs.predicted_depth

        print("Shape of predicted depth:", predicted_depth.shape)
        print("First values of predicted depth:", predicted_depth[0, :3, :3])

        # assert logits
        # expected_shape/expected_slice are only defined for the three names
        # below; presumably the argparse `choices` guarantees one matches —
        # an unknown name here would raise NameError at the asserts
        if model_name == "dpt-swinv2-base-384":
            # OK, checked
            expected_shape = torch.Size([1, 384, 384])
            expected_slice = torch.tensor(
                [
                    [1998.5575, 1997.3887, 2009.2981],
                    [1952.8607, 1979.6488, 2001.0854],
                    [1953.7697, 1961.7711, 1968.8904],
                ],
            )
        elif model_name == "dpt-swinv2-tiny-256":
            # OK, checked
            expected_shape = torch.Size([1, 256, 256])
            expected_slice = torch.tensor(
                [[978.9163, 976.5215, 978.5349], [974.1859, 971.7249, 975.8046], [971.3419, 970.3118, 971.6830]],
            )
        elif model_name == "dpt-swinv2-large-384":
            # OK, checked
            expected_shape = torch.Size([1, 384, 384])
            expected_slice = torch.tensor(
                [
                    [1203.7206, 1200.1495, 1197.8234],
                    [1196.2484, 1183.5033, 1186.4640],
                    [1178.8131, 1182.3260, 1174.3975],
                ],
            )

        assert predicted_depth.shape == torch.Size(expected_shape)
        assert torch.allclose(predicted_depth[0, :3, :3], expected_slice)
        print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model and processor to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model and processor to hub...")
        model.push_to_hub(repo_id=f"Intel/{model_name}")
        processor.push_to_hub(repo_id=f"Intel/{model_name}")
+
+
if __name__ == "__main__":
    # CLI entry point: converts one MiDaS 3.1 Swinv2 checkpoint (selected by
    # `--model_name`) to the HF format, with optional logit verification,
    # local save, and hub upload.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="dpt-swinv2-base-384",
        type=str,
        choices=["dpt-swinv2-tiny-256", "dpt-swinv2-base-384", "dpt-swinv2-large-384"],
        help="Name of the model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--verify_logits",
        action="store_true",
        help="Whether to verify logits after conversion.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the model to the hub after conversion.",
    )

    args = parser.parse_args()
    convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..55e0a444e857b0d386e685845bb4842c9f8794d7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py
@@ -0,0 +1,285 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert DPT checkpoints from the original repository. URL: https://github.com/isl-org/DPT"""
+
+import argparse
+import json
+from pathlib import Path
+
+import requests
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+
+from transformers import DPTConfig, DPTForDepthEstimation, DPTForSemanticSegmentation, DPTImageProcessor
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+
def get_dpt_config(checkpoint_url):
    """Derive a `DPTConfig` and the expected output shape from a checkpoint URL.

    Args:
        checkpoint_url (`str`): URL of the original DPT checkpoint. The substring
            "large" selects the ViT-large backbone; "ade" selects the ADE20k
            semantic-segmentation head.

    Returns:
        `tuple`: `(config, expected_shape)` — the configuration and the expected
        shape of the model output (depth map, or segmentation logits for "ade").
    """
    config = DPTConfig()

    # default output shape for depth estimation; previously this was assigned
    # only inside the "large" branch, so a non-large, non-ade checkpoint URL
    # crashed with UnboundLocalError at the return below
    expected_shape = (1, 384, 384)

    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
        config.backbone_out_indices = [5, 11, 17, 23]
        config.neck_hidden_sizes = [256, 512, 1024, 1024]

    if "ade" in checkpoint_url:
        config.use_batch_norm_in_fusion_residual = True

        # segmentation head: 150 ADE20k classes with the canonical label mapping
        config.num_labels = 150
        repo_id = "huggingface/label-files"
        filename = "ade20k-id2label.json"
        id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        expected_shape = [1, 150, 480, 480]

    return config, expected_shape
+
+
def remove_ignore_keys_(state_dict):
    """Remove the original classification head, which the DPT model does not use (in place)."""
    state_dict.pop("pretrained.model.head.weight", None)
    state_dict.pop("pretrained.model.head.bias", None)
+
+
def rename_key(name):
    """Translate one original DPT/timm parameter name into its HF equivalent.

    The substitutions are order-sensitive: each is applied to the result of the
    previous ones, so later patterns must never re-match already-renamed text.
    """
    # backbone: embedding parameters (cls token, positional/patch embeddings)
    # go under dpt.embeddings, everything else under dpt.encoder
    if "pretrained.model" in name:
        is_embedding = "cls_token" in name or "pos_embed" in name or "patch_embed" in name
        name = name.replace("pretrained.model", "dpt.embeddings" if is_embedding else "dpt.encoder")
    for old, new in (
        ("patch_embed", "patch_embeddings"),
        ("pos_embed", "position_embeddings"),
        ("attn.proj", "attention.output.dense"),
    ):
        name = name.replace(old, new)
    # "projection" itself contains "proj", hence the extra guard
    if "proj" in name and "project" not in name:
        name = name.replace("proj", "projection")
    for old, new in (
        ("blocks", "layer"),
        ("mlp.fc1", "intermediate.dense"),
        ("mlp.fc2", "output.dense"),
        ("norm1", "layernorm_before"),
        ("norm2", "layernorm_after"),
        ("scratch.output_conv", "head"),
        ("scratch", "neck"),
        ("layer1_rn", "convs.0"),
        ("layer2_rn", "convs.1"),
        ("layer3_rn", "convs.2"),
        ("layer4_rn", "convs.3"),
    ):
        name = name.replace(old, new)
    if "refinenet" in name:
        stage = int(name[len("neck.refinenet") : len("neck.refinenet") + 1])
        # tricky here: we need to map 4 to 0, 3 to 1, 2 to 2 and 1 to 3
        name = name.replace(f"refinenet{stage}", f"fusion_stage.layers.{abs(stage - 4)}")
    for old, new in (
        ("out_conv", "projection"),
        ("resConfUnit1", "residual_layer1"),
        ("resConfUnit2", "residual_layer2"),
        ("conv1", "convolution1"),
        ("conv2", "convolution2"),
    ):
        name = name.replace(old, new)
    # readout blocks of the four reassemble stages
    for idx in range(1, 5):
        name = name.replace(
            f"pretrained.act_postprocess{idx}.0.project.0",
            f"neck.reassemble_stage.readout_projects.{idx - 1}.0",
        )
    # resize blocks (stage 3 has no resize op in the original layout)
    for idx in range(1, 5):
        name = name.replace(
            f"pretrained.act_postprocess{idx}.3", f"neck.reassemble_stage.layers.{idx - 1}.projection"
        )
        if idx != 3:
            name = name.replace(
                f"pretrained.act_postprocess{idx}.4", f"neck.reassemble_stage.layers.{idx - 1}.resize"
            )
    for old, new in (
        ("pretrained", "dpt"),
        ("bn", "batch_norm"),
        ("head", "head.head"),
        ("encoder.norm", "layernorm"),
        ("auxlayer", "auxiliary_head.head"),
    ):
        name = name.replace(old, new)

    return name
+
+
def read_in_q_k_v(state_dict, config):
    """Split each fused timm qkv projection into separate q/k/v weights and biases.

    Each layer of the original checkpoint stores one (3*hidden, hidden) matrix
    and a (3*hidden,) bias, laid out as [query; key; value]. Mutates
    `state_dict` in place.
    """
    hidden = config.hidden_size
    for layer_idx in range(config.num_hidden_layers):
        fused_weight = state_dict.pop(f"dpt.encoder.layer.{layer_idx}.attn.qkv.weight")
        fused_bias = state_dict.pop(f"dpt.encoder.layer.{layer_idx}.attn.qkv.bias")
        target = f"dpt.encoder.layer.{layer_idx}.attention.attention"
        state_dict[f"{target}.query.weight"] = fused_weight[:hidden, :]
        state_dict[f"{target}.query.bias"] = fused_bias[:hidden]
        state_dict[f"{target}.key.weight"] = fused_weight[hidden : hidden * 2, :]
        state_dict[f"{target}.key.bias"] = fused_bias[hidden : hidden * 2]
        state_dict[f"{target}.value.weight"] = fused_weight[-hidden:, :]
        state_dict[f"{target}.value.bias"] = fused_bias[-hidden:]
+
+
# We will verify our results on an image of cute cats
def prepare_img():
    """Download and return the standard COCO cats test image as a `PIL.Image`."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
+
+
@torch.no_grad()
def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub, model_name):
    """
    Copy/paste/tweak model's weights to our DPT structure.

    Downloads the original DPT checkpoint from `checkpoint_url`, renames and
    splits its parameters into the HF layout ("ade" URLs become a semantic
    segmentation model, others a depth estimation model), verifies the output
    against hard-coded reference slices, and optionally saves/pushes the result.
    """

    # define DPT configuration based on URL
    config, expected_shape = get_dpt_config(checkpoint_url)
    # load original state_dict from URL
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
    # remove certain keys
    remove_ignore_keys_(state_dict)
    # rename keys (iterate over a copy since we mutate the dict while renaming)
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        state_dict[rename_key(key)] = val
    # read in qkv matrices
    read_in_q_k_v(state_dict, config)

    # load HuggingFace model
    model = DPTForSemanticSegmentation(config) if "ade" in checkpoint_url else DPTForDepthEstimation(config)
    model.load_state_dict(state_dict)
    model.eval()

    # Check outputs on an image (ADE models were trained at 480, depth at 384)
    size = 480 if "ade" in checkpoint_url else 384
    image_processor = DPTImageProcessor(size=size)

    image = prepare_img()
    encoding = image_processor(image, return_tensors="pt")

    # forward pass
    outputs = model(**encoding).logits if "ade" in checkpoint_url else model(**encoding).predicted_depth

    # Assert logits
    expected_slice = torch.tensor([[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]])
    if "ade" in checkpoint_url:
        expected_slice = torch.tensor([[4.0480, 4.2420, 4.4360], [4.3124, 4.5693, 4.8261], [4.5768, 4.8965, 5.2163]])
    assert outputs.shape == torch.Size(expected_shape)
    assert (
        torch.allclose(outputs[0, 0, :3, :3], expected_slice, atol=1e-4)
        if "ade" in checkpoint_url
        else torch.allclose(outputs[0, :3, :3], expected_slice)
    )
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
        print(f"Saving model to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        print(f"Saving image processor to {pytorch_dump_folder_path}")
        image_processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print("Pushing model to hub...")
        # NOTE(review): pushing with pytorch_dump_folder_path=None would make
        # Path(None, model_name) raise a TypeError — confirm callers always
        # pass a dump folder when --push_to_hub is set
        model.push_to_hub(
            repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
            organization="nielsr",
            commit_message="Add model",
            use_temp_dir=True,
        )
        image_processor.push_to_hub(
            repo_path_or_name=Path(pytorch_dump_folder_path, model_name),
            organization="nielsr",
            commit_message="Add image processor",
            use_temp_dir=True,
        )
+
+
if __name__ == "__main__":
    # CLI entry point: converts the DPT checkpoint at `--checkpoint_url` to the
    # HF format, optionally saving it locally and/or pushing it to the hub.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        type=str,
        help="URL of the original DPT checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=False,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
    )
    parser.add_argument(
        "--model_name",
        default="dpt-large",
        type=str,
        required=False,
        help="Name of the model, in case you're pushing to the hub.",
    )

    args = parser.parse_args()
    convert_dpt_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub, args.model_name)
diff --git a/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ab8ccbed8d33b1e5b15d429b6cb057ff781f78
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for DPT."""
+
+import warnings
+
+from ...utils import logging
+from ...utils.import_utils import requires
+from .image_processing_dpt import DPTImageProcessor
+
+
+logger = logging.get_logger(__name__)
+
+
+@requires(backends=("vision",))
+class DPTFeatureExtractor(DPTImageProcessor):
+ def __init__(self, *args, **kwargs) -> None:
+ warnings.warn(
+ "The class DPTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+ " use DPTImageProcessor instead.",
+ FutureWarning,
+ )
+ super().__init__(*args, **kwargs)
+
+
+__all__ = ["DPTFeatureExtractor"]
diff --git a/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..095cd1a48b4f08843c80b48dbff9768f98d17bb7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py
@@ -0,0 +1,680 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for DPT."""
+
+import math
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+
+from ...utils.import_utils import requires
+
+
+if TYPE_CHECKING:
+ from ...modeling_outputs import DepthEstimatorOutput
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_torch_available,
+ is_torch_tensor,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import (
+ TensorType,
+ filter_out_non_signature_kwargs,
+ is_vision_available,
+ logging,
+ requires_backends,
+)
+
+
# torch is only needed by the post-processing helpers; import it behind the availability guard.
if is_torch_available():
    import torch

# PIL backs the "vision" backend; only imported when available.
if is_vision_available():
    import PIL


# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)
+
+
def get_resize_output_image_size(
    input_image: np.ndarray,
    output_size: Union[int, Iterable[int]],
    keep_aspect_ratio: bool,
    multiple: int,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[int, int]:
    """
    Compute the (height, width) to resize `input_image` to, snapping each dimension to a
    multiple of `multiple` and optionally preserving the aspect ratio.
    """

    def _snap_to_multiple(value, step, lower=0, upper=None):
        # Round to the nearest multiple of `step`; fall back to floor/ceil to stay in [lower, upper].
        snapped = round(value / step) * step
        if upper is not None and snapped > upper:
            snapped = math.floor(value / step) * step
        if snapped < lower:
            snapped = math.ceil(value / step) * step
        return snapped

    if isinstance(output_size, int):
        target_height = target_width = output_size
    else:
        target_height, target_width = output_size

    current_height, current_width = get_image_size(input_image, input_data_format)

    # Per-axis scale factors to reach the requested size.
    scale_height = target_height / current_height
    scale_width = target_width / current_width

    if keep_aspect_ratio:
        # Apply the scale closest to 1 on both axes so the aspect ratio is unchanged.
        if abs(1 - scale_width) < abs(1 - scale_height):
            scale_height = scale_width
        else:
            scale_width = scale_height

    new_height = _snap_to_multiple(scale_height * current_height, multiple)
    new_width = _snap_to_multiple(scale_width * current_width, multiple)
    return (new_height, new_width)
+
+
@requires(backends=("vision",))
class DPTImageProcessor(BaseImageProcessor):
    r"""
    Constructs a DPT image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions. Can be overridden by `do_resize` in `preprocess`.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`):
            Size of the image after resizing. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Defines the resampling filter to use if resizing the image. Can be overridden by `resample` in
            `preprocess`.
        keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
            If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can
            be overridden by `keep_aspect_ratio` in `preprocess`.
        ensure_multiple_of (`int`, *optional*, defaults to 1):
            If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be
            overridden by `ensure_multiple_of` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
            `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in `preprocess`.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `False`):
            Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in
            combination with DPT.
        size_divisor (`int`, *optional*):
            If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
            DINOv2 paper, which uses the model in combination with DPT.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = False,
        size_divisor: Optional[int] = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 384, "width": 384}
        size = get_size_dict(size)
        self.do_resize = do_resize
        self.size = size
        self.keep_aspect_ratio = keep_aspect_ratio
        self.ensure_multiple_of = ensure_multiple_of
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_pad = do_pad
        self.size_divisor = size_divisor
        self.do_reduce_labels = do_reduce_labels

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        keep_aspect_ratio: bool = False,
        ensure_multiple_of: int = 1,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image
        is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is
        set, the image is resized to a size that is a multiple of this value.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Target size of the output image.
            keep_aspect_ratio (`bool`, *optional*, defaults to `False`):
                If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved.
            ensure_multiple_of (`int`, *optional*, defaults to 1):
                The image is resized to a size that is a multiple of this value.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image. If not provided, it will be inferred.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}")

        # The effective output size may differ from `size` when keeping the aspect ratio
        # and/or snapping to a multiple of `ensure_multiple_of`.
        output_size = get_resize_output_image_size(
            image,
            output_size=(size["height"], size["width"]),
            keep_aspect_ratio=keep_aspect_ratio,
            multiple=ensure_multiple_of,
            input_data_format=input_data_format,
        )
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )

    def pad_image(
        self,
        image: np.array,
        size_divisor: int,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Center pad an image so that both dimensions are a multiple of `size_divisor`.

        Args:
            image (`np.ndarray`):
                Image to pad.
            size_divisor (`int`):
                The width and height of the image will be padded to a multiple of this number.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """

        def _get_pad(size, size_divisor):
            # Split the total padding needed to reach the next multiple evenly on both sides
            # (the extra pixel, if any, goes to the trailing side).
            new_size = math.ceil(size / size_divisor) * size_divisor
            pad_size = new_size - size
            pad_size_before = pad_size // 2
            pad_size_after = pad_size - pad_size_before
            return pad_size_before, pad_size_after

        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        height, width = get_image_size(image, input_data_format)

        # `pad` expects ((before_height, after_height), (before_width, after_width)).
        # Previously the height padding was stored in variables named left/right (and vice
        # versa), which was misleading even though the values were passed in the right order.
        pad_top, pad_bottom = _get_pad(height, size_divisor)
        pad_left, pad_right = _get_pad(width, size_divisor)

        # Forward `input_data_format` so `pad` does not have to re-infer the channel layout
        # (re-inference can disagree for ambiguous shapes).
        return pad(
            image,
            ((pad_top, pad_bottom), (pad_left, pad_right)),
            data_format=data_format,
            input_data_format=input_data_format,
        )

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label
    def reduce_label(self, label: ImageInput) -> np.ndarray:
        """Shift label ids down by one, mapping the background label 0 to 255."""
        label = to_numpy_array(label)
        # Avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
        return label

    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: Optional[bool] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: Optional[bool] = None,
        ensure_multiple_of: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisor: Optional[int] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """Apply the enabled transformations (reduce labels, resize, rescale, normalize, pad) in order."""
        if do_reduce_labels:
            image = self.reduce_label(image)

        if do_resize:
            image = self.resize(
                image=image,
                size=size,
                resample=resample,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=ensure_multiple_of,
                input_data_format=input_data_format,
            )

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_pad:
            image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: Optional[bool] = None,
        ensure_multiple_of: Optional[int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisor: Optional[int] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)
        if do_rescale and is_scaled_image(image):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,
            keep_aspect_ratio=keep_aspect_ratio,
            ensure_multiple_of=ensure_multiple_of,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisor=size_divisor,
            input_data_format=input_data_format,
        )
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def _preprocess_segmentation_map(
        self,
        segmentation_map: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: Optional[bool] = None,
        ensure_multiple_of: Optional[int] = None,
        do_reduce_labels: Optional[bool] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """Preprocesses a single segmentation map."""
        # All transformations expect numpy arrays.
        segmentation_map = to_numpy_array(segmentation_map)
        # Add an axis to the segmentation maps for transformations.
        if segmentation_map.ndim == 2:
            segmentation_map = segmentation_map[None, ...]
            added_dimension = True
            input_data_format = ChannelDimension.FIRST
        else:
            added_dimension = False
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
        # Segmentation maps carry label ids: never rescale or normalize them.
        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            size=size,
            resample=resample,
            keep_aspect_ratio=keep_aspect_ratio,
            ensure_multiple_of=ensure_multiple_of,
            do_normalize=False,
            do_rescale=False,
            input_data_format=input_data_format,
        )
        # Remove extra axis if added
        if added_dimension:
            segmentation_map = np.squeeze(segmentation_map, axis=0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.__call__
    def __call__(self, images, segmentation_maps=None, **kwargs):
        # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both
        # be passed in as positional arguments.
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        keep_aspect_ratio: Optional[bool] = None,
        ensure_multiple_of: Optional[int] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisor: Optional[int] = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. If `keep_aspect_ratio` is `True`, the image is resized to the largest
                possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is
                resized to a size that is a multiple of this value.
            keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`):
                Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If
                True, the image will be resized to keep the aspect ratio and the size will be the maximum possible.
            ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`):
                Ensure that the image size is a multiple of this value.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        # Resolve every unset argument against the instance defaults.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        size = get_size_dict(size)
        keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio
        ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_pad = do_pad if do_pad is not None else self.do_pad
        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

        images = make_list_of_images(images)

        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisibility=size_divisor,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        images = [
            self._preprocess_image(
                image=img,
                do_resize=do_resize,
                do_rescale=do_rescale,
                do_normalize=do_normalize,
                do_pad=do_pad,
                size=size,
                resample=resample,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=ensure_multiple_of,
                rescale_factor=rescale_factor,
                image_mean=image_mean,
                image_std=image_std,
                size_divisor=size_divisor,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for img in images
        ]

        data = {"pixel_values": images}

        if segmentation_maps is not None:
            segmentation_maps = [
                self._preprocess_segmentation_map(
                    segmentation_map=segmentation_map,
                    do_reduce_labels=do_reduce_labels,
                    do_resize=do_resize,
                    size=size,
                    resample=resample,
                    keep_aspect_ratio=keep_aspect_ratio,
                    ensure_multiple_of=ensure_multiple_of,
                    input_data_format=input_data_format,
                )
                for segmentation_map in segmentation_maps
            ]

            data["labels"] = segmentation_maps

        return BatchFeature(data=data, tensor_type=return_tensors)

    # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT
    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[List[Tuple]] = None):
        """
        Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

        Args:
            outputs ([`DPTForSemanticSegmentation`]):
                Raw outputs of the model.
            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
                predictions will not be resized.

        Returns:
            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
        """
        # TODO: add support for other frameworks
        logits = outputs.logits

        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
            if len(logits) != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

            if is_torch_tensor(target_sizes):
                target_sizes = target_sizes.numpy()

            semantic_segmentation = []

            for idx in range(len(logits)):
                # Upsample each logit map to its requested size before taking the argmax.
                resized_logits = torch.nn.functional.interpolate(
                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
                )
                semantic_map = resized_logits[0].argmax(dim=0)
                semantic_segmentation.append(semantic_map)
        else:
            semantic_segmentation = logits.argmax(dim=1)
            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

        return semantic_segmentation

    def post_process_depth_estimation(
        self,
        outputs: "DepthEstimatorOutput",
        target_sizes: Optional[Union[TensorType, List[Tuple[int, int]]]] = None,
    ) -> List[Dict[str, TensorType]]:
        """
        Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
        Only supports PyTorch.

        Args:
            outputs ([`DepthEstimatorOutput`]):
                Raw outputs of the model.
            target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                (height, width) of each image in the batch. If left to None, predictions will not be resized.

        Returns:
            `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
            predictions.
        """
        requires_backends(self, "torch")

        predicted_depth = outputs.predicted_depth

        if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
            )

        results = []
        target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
        for depth, target_size in zip(predicted_depth, target_sizes):
            if target_size is not None:
                # interpolate expects a 4D (N, C, H, W) input; add and then drop the extra dims.
                depth = torch.nn.functional.interpolate(
                    depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False
                ).squeeze()

            results.append({"predicted_depth": depth})

        return results
+
+
# Explicit public API of this module.
__all__ = ["DPTImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dc1dcf2f87e072ed72b0892f0d88a9175203bd0
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py
@@ -0,0 +1,1413 @@
+# coding=utf-8
+# Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch DPT (Dense Prediction Transformers) model.
+
+This implementation is heavily inspired by OpenMMLab's implementation, found here:
+https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.
+
+"""
+
+import collections.abc
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...activations import ACT2FN
+from ...file_utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import ModelOutput, logging, torch_int
+from ...utils.backbone_utils import load_backbone
+from .configuration_dpt import DPTConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "DPTConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "Intel/dpt-large"
+_EXPECTED_OUTPUT_SHAPE = [1, 577, 1024]
+
+
@dataclass
class BaseModelOutputWithIntermediateActivations(ModelOutput):
    """
    Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
    in the context of Vision models.

    Args:
        last_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    """

    # NOTE(review): the field is `last_hidden_states` (plural), unlike the usual `last_hidden_state`
    # in other ModelOutput subclasses — kept as-is because the embeddings modules below set it by keyword.
    last_hidden_states: Optional[torch.FloatTensor] = None
    intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
@dataclass
class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
    activations that can be used by the model at later stages.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    """

    # All fields default to None; optional ones stay None when the corresponding output was not requested.
    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
class DPTViTHybridEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer. A convolutional backbone first extracts feature maps; the final map is projected to patch
    tokens, while earlier maps are kept as residual features for the neck.
    """

    def __init__(self, config, feature_size=None):
        """
        Args:
            config: DPT configuration (hybrid variant) providing image/patch sizes, hidden size and the
                backbone specification.
            feature_size (`int` or `(int, int)`, *optional*):
                Spatial size of the backbone's final feature map. When `None`, it is read from
                `config.backbone_featmap_shape`.
        """
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.backbone = load_backbone(config)
        if len(self.backbone.channels) != 3:
            raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
        self.residual_feature_map_index = [0, 1]  # Always take the output of the first and second backbone stage

        # NOTE: the original code pre-assigned feature_dim = self.backbone.channels[-1] here, but both
        # branches below overwrite it, so the dead assignment was removed.
        if feature_size is None:
            feat_map_shape = config.backbone_featmap_shape
            feature_size = feat_map_shape[-2:]
            feature_dim = feat_map_shape[1]
        else:
            feature_size = (
                feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
            )
            feature_dim = self.backbone.channels[-1]

        self.image_size = image_size
        self.patch_size = patch_size[0]
        self.num_channels = num_channels

        # 1x1 conv projecting the backbone's final feature map to the transformer width.
        self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        """Bilinearly resample the grid part of the position embeddings to a new spatial layout."""
        posemb_tok = posemb[:, :start_index]
        posemb_grid = posemb[0, start_index:]

        old_grid_size = torch_int(len(posemb_grid) ** 0.5)

        posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

        return posemb

    def forward(
        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False
    ) -> torch.Tensor:
        """
        Embed `pixel_values` via the convolutional backbone.

        Returns the [CLS] + patch token embeddings together with the intermediate backbone feature maps
        (later consumed by the neck as residual features).
        """
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        # Resample position embeddings to the actual token grid implied by the input size.
        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // self.patch_size, width // self.patch_size
        )

        backbone_output = self.backbone(pixel_values)

        features = backbone_output.feature_maps[-1]

        # Retrieve also the intermediate activations to use them at later stages
        output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]

        embeddings = self.projection(features).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        embeddings = embeddings + position_embeddings

        if not return_dict:
            return (embeddings, output_hidden_states)

        # Return hidden states and intermediate activations
        return BaseModelOutputWithIntermediateActivations(
            last_hidden_states=embeddings,
            intermediate_activations=output_hidden_states,
        )
+
+
class DPTViTEmbeddings(nn.Module):
    """
    Builds token embeddings for the ViT encoder: patch embeddings plus a [CLS] token and
    (resampled) learned position embeddings, followed by dropout.
    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = DPTViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
        # Split special-token embeddings from the grid part, resample the grid bilinearly
        # to the requested spatial layout, then stitch the two back together.
        token_part = posemb[:, :start_index]
        grid_part = posemb[0, start_index:]

        old_grid_size = torch_int(grid_part.size(0) ** 0.5)

        grid_part = grid_part.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
        grid_part = nn.functional.interpolate(grid_part, size=(grid_size_height, grid_size_width), mode="bilinear")
        grid_part = grid_part.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)

        return torch.cat([token_part, grid_part], dim=1)

    def forward(self, pixel_values, return_dict=False):
        batch_size, num_channels, height, width = pixel_values.shape

        # Position encodings are resampled so arbitrary input resolutions are supported.
        patch_size = self.config.patch_size
        position_embeddings = self._resize_pos_embed(
            self.position_embeddings, height // patch_size, width // patch_size
        )

        tokens = self.patch_embeddings(pixel_values)

        # Prepend the learnable [CLS] token, then add position information and regularize.
        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.dropout(tokens + position_embeddings)

        if not return_dict:
            return (tokens,)

        return BaseModelOutputWithIntermediateActivations(last_hidden_states=tokens)
+
+
class DPTViTPatchEmbeddings(nn.Module):
    """
    Splits an image into non-overlapping patches and linearly projects each patch,
    implemented as a single strided convolution.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        # Patches per axis multiplied together.
        self.num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        # stride == kernel == patch size performs patchify + projection in one op.
        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # (B, C, H, W) -> (B, hidden, H/p, W/p) -> (B, num_patches, hidden)
        return self.projection(pixel_values).flatten(2).transpose(1, 2)
+
+
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (pure-PyTorch) scaled dot-product attention.

    Returns `(context, probs)` where `context` is already transposed to `(batch, seq, heads, head_dim)`
    and `probs` are the post-dropout, post-mask attention weights.
    """
    # Raw scores: scaled dot product of queries and keys.
    scores = torch.matmul(query, key.transpose(-1, -2))
    scores = scores * scaling

    # Softmax in float32 for numerical stability, then cast back.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)

    # Drops whole tokens to attend to — unusual, but taken from the original Transformer paper.
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    # Multiplicative mask (used for head masking).
    if attention_mask is not None:
        probs = probs * attention_mask

    context = torch.matmul(probs, value)
    context = context.transpose(1, 2).contiguous()

    return context, probs
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
class DPTSelfAttention(nn.Module):
    """Multi-head self-attention with a configurable attention backend (eager / SDPA / etc.)."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        # Head count must divide the hidden size (unless an explicit embedding_size is configured).
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        # 1/sqrt(head_dim) scaling applied to the attention scores.
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Choose the attention implementation; SDPA cannot return attention probabilities,
        # so fall back to the eager path when `output_attentions=True`.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        # Merge the heads back into the hidden dimension.
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
+
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DPT
class DPTViTSelfOutput(nn.Module):
    """
    Projects the attention context back to `hidden_size`. The residual connection is defined in
    DPTViTLayer instead of here (as is the case with other models), due to the layernorm applied
    before each block — `input_tensor` is accepted but intentionally unused.
    """

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Project, then regularize; the caller adds the residual.
        projected = self.dense(hidden_states)
        return self.dropout(projected)
+
+
class DPTViTAttention(nn.Module):
    """Self-attention plus its output projection, with support for head pruning."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.attention = DPTSelfAttention(config)
        self.output = DPTViTSelfOutput(config)
        self.pruned_heads = set()

    # Copied from transformers.models.vit.modeling_vit.ViTAttention.prune_heads
    def prune_heads(self, heads: Set[int]) -> None:
        """Remove the given attention heads from the query/key/value/output projections."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Drop the pruned heads' slices from the four projections.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the bookkeeping consistent with the reduced head count.
        self.attention.num_attention_heads -= len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    # Copied from transformers.models.vit.modeling_vit.ViTAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Append attention probabilities when they were requested.
        return (attention_output,) + self_outputs[1:]
+
+
+# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DPT
+class DPTViTIntermediate(nn.Module):
+ def __init__(self, config: DPTConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DPT
class DPTViTOutput(nn.Module):
    """MLP contraction: intermediate_size -> hidden_size, dropout, then residual add."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection back onto the block input.
        return projected + input_tensor
+
+
# copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput
class DPTViTLayer(nn.Module):
    """Transformer block (the `Block` class in the timm implementation): pre-LN attention and pre-LN MLP, each with a residual."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = DPTViTAttention(config)
        self.intermediate = DPTViTIntermediate(config)
        self.output = DPTViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Pre-layernorm attention; `attention_outputs[1:]` carries the attention probs if requested.
        normed = self.layernorm_before(hidden_states)
        attention_outputs = self.attention(normed, head_mask, output_attentions=output_attentions)

        # First residual connection.
        hidden_states = attention_outputs[0] + hidden_states

        # Pre-layernorm MLP; DPTViTOutput adds the second residual internally.
        mlp_input = self.layernorm_after(hidden_states)
        layer_output = self.output(self.intermediate(mlp_input), hidden_states)

        return (layer_output,) + attention_outputs[1:]
+
+
# copied from transformers.models.vit.modeling_vit.ViTEncoder with ViTConfig -> DPTConfig, ViTLayer->DPTViTLayer
class DPTViTEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` DPTViTLayer blocks."""

    def __init__(self, config: DPTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(DPTViTLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """Run all layers, optionally collecting per-layer hidden states and attentions."""
        collected_hidden = () if output_hidden_states else None
        collected_attn = () if output_attentions else None

        for idx, block in enumerate(self.layer):
            # The embedding output / previous layer output is recorded *before* the layer runs.
            if output_hidden_states:
                collected_hidden = collected_hidden + (hidden_states,)

            mask = None if head_mask is None else head_mask[idx]

            if self.gradient_checkpointing and self.training:
                # Recompute activations in the backward pass to save memory.
                outputs = self._gradient_checkpointing_func(block.__call__, hidden_states, mask, output_attentions)
            else:
                outputs = block(hidden_states, mask, output_attentions)

            hidden_states = outputs[0]
            if output_attentions:
                collected_attn = collected_attn + (outputs[1],)

        if output_hidden_states:
            collected_hidden = collected_hidden + (hidden_states,)

        if not return_dict:
            return tuple(v for v in (hidden_states, collected_hidden, collected_attn) if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=collected_hidden,
            attentions=collected_attn,
        )
+
+
class DPTReassembleStage(nn.Module):
    """
    This class reassembles the hidden states of the backbone into image-like feature representations at various
    resolutions.

    This happens in 3 stages:
    1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
       `config.readout_type`.
    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resizing the spatial dimensions (height, width).

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.layers = nn.ModuleList()
        # Hybrid and pure-ViT variants wire up the reassemble layers differently.
        if config.is_hybrid:
            self._init_reassemble_dpt_hybrid(config)
        else:
            self._init_reassemble_dpt(config)

        # Stages listed here are passed through unchanged in forward().
        self.neck_ignore_stages = config.neck_ignore_stages

    def _init_reassemble_dpt_hybrid(self, config):
        r"""
        For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
        implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
        for more details.
        """
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            if i <= 1:
                # First two stages come from the convolutional backbone and need no reassembly.
                self.layers.append(nn.Identity())
            elif i > 1:
                self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type != "project":
            raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")

        # When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
        self.readout_projects = nn.ModuleList()
        hidden_size = _get_backbone_hidden_size(config)
        for i in range(len(config.neck_hidden_sizes)):
            if i <= 1:
                # Identity placeholder so indices line up with self.layers.
                self.readout_projects.append(nn.Sequential(nn.Identity()))
            elif i > 1:
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
                )

    def _init_reassemble_dpt(self, config):
        # One reassemble layer per neck stage, each with its own up/down-sampling factor.
        for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
            self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))

        if config.readout_type == "project":
            # "project" readout concatenates the CLS token to every patch token, hence 2 * hidden_size inputs.
            self.readout_projects = nn.ModuleList()
            hidden_size = _get_backbone_hidden_size(config)
            for _ in range(len(config.neck_hidden_sizes)):
                self.readout_projects.append(
                    nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
                )

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                List of hidden states from the backbone.
        """
        out = []

        for i, hidden_state in enumerate(hidden_states):
            if i not in self.neck_ignore_stages:
                # reshape to (batch_size, num_channels, height, width)
                cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
                batch_size, sequence_length, num_channels = hidden_state.shape
                if patch_height is not None and patch_width is not None:
                    hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
                else:
                    # Without explicit patch dims, assume a square token grid.
                    size = torch_int(sequence_length**0.5)
                    hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()

                feature_shape = hidden_state.shape
                if self.config.readout_type == "project":
                    # reshape to (batch_size, height*width, num_channels)
                    hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
                    readout = cls_token.unsqueeze(1).expand_as(hidden_state)
                    # concatenate the readout token to the hidden states and project
                    hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
                    # reshape back to (batch_size, num_channels, height, width)
                    hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
                elif self.config.readout_type == "add":
                    # "add" readout: broadcast the CLS token over all spatial positions.
                    hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
                    hidden_state = hidden_state.reshape(feature_shape)
                hidden_state = self.layers[i](hidden_state)
            out.append(hidden_state)

        return out
+
+
+def _get_backbone_hidden_size(config):
+ if config.backbone_config is not None and config.is_hybrid is False:
+ return config.backbone_config.hidden_size
+ else:
+ return config.hidden_size
+
+
class DPTReassembleLayer(nn.Module):
    """
    Projects transformer features to a stage-specific channel count, then resizes spatially by
    `factor` (> 1 upsamples, == 1 is a no-op, < 1 downsamples).
    """

    def __init__(self, config, channels, factor):
        super().__init__()
        hidden_size = _get_backbone_hidden_size(config)
        # 1x1 conv maps the transformer width to this stage's channel count.
        self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)

        if factor > 1:
            # Transposed conv upsamples spatially by `factor`.
            self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
        elif factor == 1:
            self.resize = nn.Identity()
        elif factor < 1:
            # Strided conv downsamples by 1/factor.
            self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)

    def forward(self, hidden_state):
        return self.resize(self.projection(hidden_state))
+
+
class DPTFeatureFusionStage(nn.Module):
    """Fuses the reassembled feature maps, walking from the deepest to the shallowest stage."""

    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList(DPTFeatureFusionLayer(config) for _ in range(len(config.neck_hidden_sizes)))

    def forward(self, hidden_states):
        # Deepest feature first; each step refines the running fusion with the next shallower map.
        fused = None
        outputs = []
        for layer, state in zip(self.layers, hidden_states[::-1]):
            fused = layer(state) if fused is None else layer(fused, state)
            outputs.append(fused)
        return outputs
+
+
class DPTPreActResidualLayer(nn.Module):
    """
    ResidualConvUnit, pre-activate residual unit: two (ReLU -> conv [-> batch norm]) stages
    plus an identity skip connection.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    """

    def __init__(self, config):
        super().__init__()

        self.use_batch_norm = config.use_batch_norm_in_fusion_residual
        if config.use_bias_in_fusion_residual is None:
            # Bias is redundant when a batch norm immediately follows the convolution.
            conv_bias = not self.use_batch_norm
        else:
            conv_bias = config.use_bias_in_fusion_residual

        def _make_conv():
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(
                config.fusion_hidden_size,
                config.fusion_hidden_size,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=conv_bias,
            )

        self.activation1 = nn.ReLU()
        self.convolution1 = _make_conv()
        self.activation2 = nn.ReLU()
        self.convolution2 = _make_conv()

        if self.use_batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
            self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        residual = hidden_state

        out = self.activation1(hidden_state)
        out = self.convolution1(out)
        if self.use_batch_norm:
            out = self.batch_norm1(out)

        out = self.activation2(out)
        out = self.convolution2(out)
        if self.use_batch_norm:
            out = self.batch_norm2(out)

        return out + residual
+
+
class DPTFeatureFusionLayer(nn.Module):
    """Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
        align_corners (`bool`, *optional*, defaults to `True`):
            The align_corner setting for the final bilinear upsample.
    """

    def __init__(self, config, align_corners=True):
        super().__init__()
        self.align_corners = align_corners
        self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
        self.residual_layer1 = DPTPreActResidualLayer(config)
        self.residual_layer2 = DPTPreActResidualLayer(config)

    def forward(self, hidden_state, residual=None):
        if residual is not None:
            if hidden_state.shape != residual.shape:
                # NOTE(review): this interpolation hard-codes align_corners=False instead of using
                # self.align_corners — kept as-is to match the reference behavior.
                residual = nn.functional.interpolate(
                    residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
                )
            hidden_state = hidden_state + self.residual_layer1(residual)

        refined = self.residual_layer2(hidden_state)
        upsampled = nn.functional.interpolate(
            refined, scale_factor=2, mode="bilinear", align_corners=self.align_corners
        )
        return self.projection(upsampled)
+
+
class DPTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DPTConfig
    base_model_prefix = "dpt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
            # Normalization layers start out as the identity transform.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Plain `if` (not `elif`): embeddings modules additionally zero their CLS token
        # and position embedding parameters.
        if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
            module.cls_token.data.zero_()
            module.position_embeddings.data.zero_()
+
+
DPT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`DPTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DPT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare DPT Model transformer outputting raw hidden-states without any specific head on top.",
    DPT_START_DOCSTRING,
)
class DPTModel(DPTPreTrainedModel):
    """ViT(-hybrid) trunk of DPT: embeddings plus transformer encoder, with an optional pooler."""

    def __init__(self, config, add_pooling_layer=True):
        """
        Args:
            config: DPT model configuration.
            add_pooling_layer (`bool`, *optional*, defaults to `True`):
                Whether to attach a [`DPTViTPooler`] on top of the encoder output.
        """
        super().__init__(config)
        self.config = config

        # vit encoder
        # Hybrid models embed pixels through a CNN backbone first; plain models
        # use standard ViT patch embeddings.
        if config.is_hybrid:
            self.embeddings = DPTViTHybridEmbeddings(config)
        else:
            self.embeddings = DPTViTEmbeddings(config)
        self.encoder = DPTViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = DPTViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # The hybrid embedding module is itself the input embedding (it wraps a
        # CNN backbone); the plain ViT path exposes its patch projection.
        if self.config.is_hybrid:
            return self.embeddings
        else:
            return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndIntermediateActivations,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndIntermediateActivations]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, return_dict=return_dict)

        # NOTE: the embedding output field is `last_hidden_states` (plural) — a
        # DPT-specific output class, not the usual `last_hidden_state`.
        embedding_last_hidden_states = embedding_output[0] if not return_dict else embedding_output.last_hidden_states

        encoder_outputs = self.encoder(
            embedding_last_hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            # embedding_output[1:] carries the (possibly empty) intermediate
            # activations produced by the hybrid embeddings.
            return head_outputs + encoder_outputs[1:] + embedding_output[1:]

        return BaseModelOutputWithPoolingAndIntermediateActivations(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            intermediate_activations=embedding_output.intermediate_activations,
        )
+
+
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DPT
class DPTViTPooler(nn.Module):
    # Pools the encoder output by projecting the first ([CLS]) token's hidden
    # state through a dense layer and an activation.
    # NOTE(review): kept byte-identical to the ViT original because of the
    # "Copied from" marker enforced by `make fix-copies`.
    def __init__(self, config: DPTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
        self.activation = ACT2FN[config.pooler_act]

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
+
+
class DPTNeck(nn.Module):
    """
    DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DPT, it includes 2 stages:

    * DPTReassembleStage
    * DPTFeatureFusionStage.

    Args:
        config (dict): config dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
        backbone_config = config.backbone_config
        if backbone_config is not None and backbone_config.model_type in ["swinv2"]:
            self.reassemble_stage = None
        else:
            self.reassemble_stage = DPTReassembleStage(config)

        # One 3x3 projection per backbone stage, mapping to the fusion width.
        self.convs = nn.ModuleList(
            nn.Conv2d(channels, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)
            for channels in config.neck_hidden_sizes
        )

        # fusion
        self.fusion_stage = DPTFeatureFusionStage(config)

    def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
        """
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                List of hidden states from the backbone.
        """
        if not isinstance(hidden_states, (tuple, list)):
            raise TypeError("hidden_states should be a tuple or list of tensors")

        if len(hidden_states) != len(self.config.neck_hidden_sizes):
            raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")

        # postprocess hidden states (token sequences -> 2D feature maps)
        if self.reassemble_stage is not None:
            hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)

        features = []
        for conv, feature_map in zip(self.convs, hidden_states):
            features.append(conv(feature_map))

        # fusion blocks
        return self.fusion_stage(features)
+
+
class DPTDepthEstimationHead(nn.Module):
    """
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config

        # Optional extra 3x3 projection applied before the head proper.
        self.projection = (
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) if config.add_projection else None
        )

        num_features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(num_features, num_features // 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(num_features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # Only the fused feature map selected by head_in_index feeds the head.
        selected = hidden_states[self.config.head_in_index]

        if self.projection is not None:
            selected = nn.ReLU()(self.projection(selected))

        # Drop the singleton channel dimension: (batch, 1, H, W) -> (batch, H, W).
        return self.head(selected).squeeze(dim=1)
+
+
@add_start_docstrings(
    """
    DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
    """,
    DPT_START_DOCSTRING,
)
class DPTForDepthEstimation(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Exactly one of `self.backbone` / `self.dpt` is created: a standalone
        # backbone for non-hybrid configs that specify one, otherwise a full
        # DPTModel (plain or hybrid ViT) without the pooler.
        self.backbone = None
        if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None):
            self.backbone = load_backbone(config)
        else:
            self.dpt = DPTModel(config, add_pooling_layer=False)

        # Neck
        self.neck = DPTNeck(config)

        # Depth estimation head
        self.head = DPTDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate to original size
        >>> post_processed_output = image_processor.post_process_depth_estimation(
        ...     outputs,
        ...     target_sizes=[(image.height, image.width)],
        ... )

        >>> # visualize the prediction
        >>> predicted_depth = post_processed_output[0]["predicted_depth"]
        >>> depth = predicted_depth * 255 / predicted_depth.max()
        >>> depth = depth.detach().cpu().numpy()
        >>> depth = Image.fromarray(depth.astype("uint8"))
        ```"""
        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if self.backbone is not None:
            # Standalone backbone path: feature maps come out ready for the neck.
            outputs = self.backbone.forward_with_filtered_kwargs(
                pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
            )
            hidden_states = outputs.feature_maps
        else:
            # DPTModel path: always request hidden states, then select the ones
            # the neck consumes.
            outputs = self.dpt(
                pixel_values,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=True,  # we need the intermediate hidden states
                return_dict=return_dict,
            )
            hidden_states = outputs.hidden_states if return_dict else outputs[1]
            # only keep certain features based on config.backbone_out_indices
            # note that the hidden_states also include the initial embeddings
            if not self.config.is_hybrid:
                hidden_states = [
                    feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
                ]
            else:
                # Hybrid path: the first neck inputs come from the CNN backbone's
                # intermediate activations, the remainder from transformer layers
                # listed in backbone_out_indices[2:].
                backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
                backbone_hidden_states.extend(
                    feature
                    for idx, feature in enumerate(hidden_states[1:])
                    if idx in self.config.backbone_out_indices[2:]
                )

                hidden_states = backbone_hidden_states

        # Patch grid dimensions are only needed to reassemble token sequences
        # from a non-hybrid ViT-style backbone into 2D maps.
        patch_height, patch_width = None, None
        if self.config.backbone_config is not None and self.config.is_hybrid is False:
            _, _, height, width = pixel_values.shape
            patch_size = self.config.backbone_config.patch_size
            patch_height = height // patch_size
            patch_width = width // patch_size

        hidden_states = self.neck(hidden_states, patch_height, patch_width)

        predicted_depth = self.head(hidden_states)

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
+
+
class DPTSemanticSegmentationHead(nn.Module):
    """Segmentation head: conv + BN + ReLU + dropout + per-pixel classifier, then a 2x bilinear upsample."""

    def __init__(self, config):
        super().__init__()

        self.config = config

        num_features = config.fusion_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features),
            nn.ReLU(),
            nn.Dropout(config.semantic_classifier_dropout),
            nn.Conv2d(num_features, config.num_labels, kernel_size=1),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # Only the fused feature map selected by head_in_index is classified.
        selected = hidden_states[self.config.head_in_index]
        return self.head(selected)
+
+
class DPTAuxiliaryHead(nn.Module):
    """Auxiliary segmentation head used for deep supervision during training."""

    def __init__(self, config):
        super().__init__()

        num_features = config.fusion_hidden_size
        layers = [
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_features),
            nn.ReLU(),
            nn.Dropout(0.1, False),
            nn.Conv2d(num_features, config.num_labels, kernel_size=1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, hidden_states):
        # Per-pixel class logits at the feature-map resolution (no upsampling).
        return self.head(hidden_states)
+
+
@add_start_docstrings(
    """
    DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    DPT_START_DOCSTRING,
)
class DPTForSemanticSegmentation(DPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Bare DPT trunk; no pooler is needed for dense prediction.
        self.dpt = DPTModel(config, add_pooling_layer=False)

        # Neck
        self.neck = DPTNeck(config)

        # Segmentation head(s)
        self.head = DPTSemanticSegmentationHead(config)
        self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
        >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.dpt(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features based on config.backbone_out_indices
        # note that the hidden_states also include the initial embeddings
        if not self.config.is_hybrid:
            hidden_states = [
                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
            ]
        else:
            # Hybrid path: the first neck inputs come from the CNN backbone's
            # intermediate activations, the rest from the transformer layers
            # listed in backbone_out_indices[2:].
            backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
            backbone_hidden_states.extend(
                feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
            )

            hidden_states = backbone_hidden_states

        hidden_states = self.neck(hidden_states=hidden_states)

        logits = self.head(hidden_states)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(hidden_states[-1])

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
            loss = loss_fct(upsampled_logits, labels)
            # BUGFIX: only add the auxiliary term when the auxiliary head exists.
            # Previously `upsampled_auxiliary_logits` was referenced
            # unconditionally, raising NameError when use_auxiliary_head=False.
            if auxiliary_logits is not None:
                upsampled_auxiliary_logits = nn.functional.interpolate(
                    auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                # compute weighted loss
                loss = loss + self.config.auxiliary_loss_weight * loss_fct(upsampled_auxiliary_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )
+
+
# Public API of this module.
__all__ = ["DPTForDepthEstimation", "DPTForSemanticSegmentation", "DPTModel", "DPTPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py b/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..24d58e81167ec8729d04fa52ce96ebc1737a5982
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
# When type checking, import every submodule eagerly so static analyzers see
# the public symbols; at runtime, replace this module with a lazy loader so the
# heavy submodules are only imported on first attribute access.
if TYPE_CHECKING:
    from .configuration_efficientnet import *
    from .image_processing_efficientnet import *
    from .image_processing_efficientnet_fast import *
    from .modeling_efficientnet import *
else:
    import sys

    _file = globals()["__file__"]
    # _LazyModule defers submodule imports until their attributes are accessed.
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..29df2ce0e34a8291ee484b3f93a96081f658d1db
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""EfficientNet model configuration"""
+
from collections import OrderedDict
from typing import List, Mapping, Optional

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class EfficientNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`EfficientNetModel`]. It is used to instantiate an
    EfficientNet model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the EfficientNet
    [google/efficientnet-b7](https://huggingface.co/google/efficientnet-b7) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 600):
            The input image size.
        width_coefficient (`float`, *optional*, defaults to 2.0):
            Scaling coefficient for network width at each stage.
        depth_coefficient (`float`, *optional*, defaults to 3.1):
            Scaling coefficient for network depth at each stage.
        depth_divisor (`int`, *optional*, defaults to 8):
            A unit of network width.
        kernel_sizes (`List[int]`, *optional*, defaults to `[3, 3, 5, 3, 5, 5, 3]`):
            List of kernel sizes to be used in each block.
        in_channels (`List[int]`, *optional*, defaults to `[32, 16, 24, 40, 80, 112, 192]`):
            List of input channel sizes to be used in each block for convolutional layers.
        out_channels (`List[int]`, *optional*, defaults to `[16, 24, 40, 80, 112, 192, 320]`):
            List of output channel sizes to be used in each block for convolutional layers.
        depthwise_padding (`List[int]`, *optional*, defaults to `[]`):
            List of block indices with square padding.
        strides (`List[int]`, *optional*, defaults to `[1, 2, 2, 2, 1, 2, 1]`):
            List of stride sizes to be used in each block for convolutional layers.
        num_block_repeats (`List[int]`, *optional*, defaults to `[1, 2, 2, 3, 3, 4, 1]`):
            List of the number of times each block is to be repeated.
        expand_ratios (`List[int]`, *optional*, defaults to `[1, 6, 6, 6, 6, 6, 6]`):
            List of scaling coefficient of each block.
        squeeze_expansion_ratio (`float`, *optional*, defaults to 0.25):
            Squeeze expansion ratio.
        hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
            `"selu"`, `"gelu_new"`, `"silu"` and `"mish"` are supported.
        hidden_dim (`int`, *optional*, defaults to 2560):
            The hidden dimension of the layer before the classification head.
        pooling_type (`str` or `function`, *optional*, defaults to `"mean"`):
            Type of final pooling to be applied before the dense classification head. Available options are [`"mean"`,
            `"max"`]
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        batch_norm_eps (`float`, *optional*, defaults to 1e-3):
            The epsilon used by the batch normalization layers.
        batch_norm_momentum (`float`, *optional*, defaults to 0.99):
            The momentum used by the batch normalization layers.
        dropout_rate (`float`, *optional*, defaults to 0.5):
            The dropout rate to be applied before final classifier layer.
        drop_connect_rate (`float`, *optional*, defaults to 0.2):
            The drop rate for skip connections.

    Example:
    ```python
    >>> from transformers import EfficientNetConfig, EfficientNetModel

    >>> # Initializing a EfficientNet efficientnet-b7 style configuration
    >>> configuration = EfficientNetConfig()

    >>> # Initializing a model (with random weights) from the efficientnet-b7 style configuration
    >>> model = EfficientNetModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "efficientnet"

    def __init__(
        self,
        num_channels: int = 3,
        image_size: int = 600,
        width_coefficient: float = 2.0,
        depth_coefficient: float = 3.1,
        depth_divisor: int = 8,
        kernel_sizes: Optional[List[int]] = None,
        in_channels: Optional[List[int]] = None,
        out_channels: Optional[List[int]] = None,
        depthwise_padding: Optional[List[int]] = None,
        strides: Optional[List[int]] = None,
        num_block_repeats: Optional[List[int]] = None,
        expand_ratios: Optional[List[int]] = None,
        squeeze_expansion_ratio: float = 0.25,
        hidden_act: str = "swish",
        hidden_dim: int = 2560,
        pooling_type: str = "mean",
        initializer_range: float = 0.02,
        batch_norm_eps: float = 0.001,
        batch_norm_momentum: float = 0.99,
        dropout_rate: float = 0.5,
        drop_connect_rate: float = 0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_channels = num_channels
        self.image_size = image_size
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.depth_divisor = depth_divisor
        # List defaults are resolved here instead of in the signature to avoid
        # sharing one mutable list object across every config instance
        # (the classic mutable-default-argument pitfall). The resulting values
        # are identical to the previous inline defaults.
        self.kernel_sizes = [3, 3, 5, 3, 5, 5, 3] if kernel_sizes is None else kernel_sizes
        self.in_channels = [32, 16, 24, 40, 80, 112, 192] if in_channels is None else in_channels
        self.out_channels = [16, 24, 40, 80, 112, 192, 320] if out_channels is None else out_channels
        self.depthwise_padding = [] if depthwise_padding is None else depthwise_padding
        self.strides = [1, 2, 2, 2, 1, 2, 1] if strides is None else strides
        self.num_block_repeats = [1, 2, 2, 3, 3, 4, 1] if num_block_repeats is None else num_block_repeats
        self.expand_ratios = [1, 6, 6, 6, 6, 6, 6] if expand_ratios is None else expand_ratios
        self.squeeze_expansion_ratio = squeeze_expansion_ratio
        self.hidden_act = hidden_act
        self.hidden_dim = hidden_dim
        self.pooling_type = pooling_type
        self.initializer_range = initializer_range
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        # NOTE(review): factor of 4 (sub-layers counted per block repeat) kept
        # exactly as in the original implementation.
        self.num_hidden_layers = sum(self.num_block_repeats) * 4
+
+
class EfficientNetOnnxConfig(OnnxConfig):
    """ONNX export configuration for EfficientNet models."""

    # Minimum torch version required for exporting this model to ONNX.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Single image input; all four axes are exported as dynamic.
        axes = {0: "batch", 1: "num_channels", 2: "height", 3: "width"}
        return OrderedDict([("pixel_values", axes)])

    @property
    def atol_for_validation(self) -> float:
        # Absolute tolerance used when validating the export against PyTorch.
        return 1e-5
+
+
# Public API of this module.
__all__ = ["EfficientNetConfig", "EfficientNetOnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py b/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9988524aca04de2a1d600586ff01d9b9a3ea6c2
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py
@@ -0,0 +1,339 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert EfficientNet checkpoints from the original repository.
+
+URL: https://github.com/keras-team/keras/blob/v2.11.0/keras/applications/efficientnet.py"""
+
+import argparse
+import json
+import os
+
+import numpy as np
+import PIL
+import requests
+import tensorflow.keras.applications.efficientnet as efficientnet
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from tensorflow.keras.preprocessing import image
+
+from transformers import (
+ EfficientNetConfig,
+ EfficientNetForImageClassification,
+ EfficientNetImageProcessor,
+)
+from transformers.utils import logging
+
+
# Conversion scripts are run interactively; surface progress at INFO level.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
+
# Map each EfficientNet variant name ("b0".."b7") to its Keras constructor
# (EfficientNetB0..EfficientNetB7).
model_classes = {f"b{i}": getattr(efficientnet, f"EfficientNetB{i}") for i in range(8)}
+
# Per-variant hyperparameters from the original Keras EfficientNet release:
#   hidden_dim:   width of the top (pre-classifier) convolution
#   width_coef:   channel-width multiplier applied to the base architecture
#   depth_coef:   block-depth multiplier applied to the base architecture
#   image_size:   square input resolution the checkpoint was trained at
#   dropout_rate: dropout probability before the classifier head
#   dw_padding:   block indices fed into `config.depthwise_padding` —
#                 NOTE(review): exact padding semantics live in the modeling
#                 code; confirm there.
CONFIG_MAP = {
    "b0": {
        "hidden_dim": 1280,
        "width_coef": 1.0,
        "depth_coef": 1.0,
        "image_size": 224,
        "dropout_rate": 0.2,
        "dw_padding": [],
    },
    "b1": {
        "hidden_dim": 1280,
        "width_coef": 1.0,
        "depth_coef": 1.1,
        "image_size": 240,
        "dropout_rate": 0.2,
        "dw_padding": [16],
    },
    "b2": {
        "hidden_dim": 1408,
        "width_coef": 1.1,
        "depth_coef": 1.2,
        "image_size": 260,
        "dropout_rate": 0.3,
        "dw_padding": [5, 8, 16],
    },
    "b3": {
        "hidden_dim": 1536,
        "width_coef": 1.2,
        "depth_coef": 1.4,
        "image_size": 300,
        "dropout_rate": 0.3,
        "dw_padding": [5, 18],
    },
    "b4": {
        "hidden_dim": 1792,
        "width_coef": 1.4,
        "depth_coef": 1.8,
        "image_size": 380,
        "dropout_rate": 0.4,
        "dw_padding": [6],
    },
    "b5": {
        "hidden_dim": 2048,
        "width_coef": 1.6,
        "depth_coef": 2.2,
        "image_size": 456,
        "dropout_rate": 0.4,
        "dw_padding": [13, 27],
    },
    "b6": {
        "hidden_dim": 2304,
        "width_coef": 1.8,
        "depth_coef": 2.6,
        "image_size": 528,
        "dropout_rate": 0.5,
        "dw_padding": [31],
    },
    "b7": {
        "hidden_dim": 2560,
        "width_coef": 2.0,
        "depth_coef": 3.1,
        "image_size": 600,
        "dropout_rate": 0.5,
        "dw_padding": [18],
    },
}
+
+
def get_efficientnet_config(model_name):
    """Build an `EfficientNetConfig` for the given variant.

    Args:
        model_name (`str`): Variant key, one of "b0".."b7".

    Returns:
        `EfficientNetConfig`: config carrying the per-variant hyperparameters
        from `CONFIG_MAP` plus the ImageNet-1k id/label maps from the hub.
    """
    config = EfficientNetConfig()
    variant = CONFIG_MAP[model_name]
    config.hidden_dim = variant["hidden_dim"]
    config.width_coefficient = variant["width_coef"]
    config.depth_coefficient = variant["depth_coef"]
    config.image_size = variant["image_size"]
    config.dropout_rate = variant["dropout_rate"]
    config.depthwise_padding = variant["dw_padding"]

    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    config.num_labels = 1000
    # Use a context manager so the label file handle is closed deterministically
    # (the original `json.load(open(...))` leaked the file object).
    with open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r") as f:
        id2label = json.load(f)
    id2label = {int(k): v for k, v in id2label.items()}

    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    return config
+
+
+# We will verify our results on an image of cute cats
def prepare_img():
    """Download the COCO cats test image used to verify the conversion."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    response = requests.get(url, stream=True)
    return Image.open(response.raw)
+
+
def convert_image_processor(model_name):
    """Create an `EfficientNetImageProcessor` matching the variant's input resolution."""
    image_size = CONFIG_MAP[model_name]["image_size"]
    return EfficientNetImageProcessor(
        size={"height": image_size, "width": image_size},
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.47853944, 0.4732864, 0.47434163],
        do_center_crop=False,
    )
+
+
+# here we list all keys to be renamed (original name on the left, our name on the right)
+def rename_keys(original_param_names):
+ block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
+ block_names = sorted(set(block_names))
+ num_blocks = len(block_names)
+ block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}
+
+ rename_keys = []
+ rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
+ rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
+ rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
+ rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
+ rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))
+
+ for b in block_names:
+ hf_b = block_name_mapping[b]
+ rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
+ rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
+ rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
+ rename_keys.append(
+ (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
+ )
+ rename_keys.append(
+ (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
+ )
+ rename_keys.append(
+ (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
+ )
+ rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
+ rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
+ rename_keys.append(
+ (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
+ )
+ rename_keys.append(
+ (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
+ )
+
+ rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
+ rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
+ rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
+ rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
+ rename_keys.append(
+ (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
+ )
+ rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
+ rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
+ rename_keys.append(
+ (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
+ )
+ rename_keys.append(
+ (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
+ )
+
+ rename_keys.append(("top_conv/kernel:0", "encoder.top_conv.weight"))
+ rename_keys.append(("top_bn/gamma:0", "encoder.top_bn.weight"))
+ rename_keys.append(("top_bn/beta:0", "encoder.top_bn.bias"))
+ rename_keys.append(("top_bn/moving_mean:0", "encoder.top_bn.running_mean"))
+ rename_keys.append(("top_bn/moving_variance:0", "encoder.top_bn.running_var"))
+
+ key_mapping = {}
+ for item in rename_keys:
+ if item[0] in original_param_names:
+ key_mapping[item[0]] = "efficientnet." + item[1]
+
+ key_mapping["predictions/kernel:0"] = "classifier.weight"
+ key_mapping["predictions/bias:0"] = "classifier.bias"
+ return key_mapping
+
+
def replace_params(hf_params, tf_params, key_mapping):
    """Copy original TF weights into the HF state dict in place, transposing as needed."""
    for tf_key, tf_value in tf_params.items():
        # Keras `Normalization` layer variables have no HF counterpart; skip them.
        if "normalization" in tf_key:
            continue

        hf_key = key_mapping[tf_key]
        if "_conv" in tf_key and "kernel" in tf_key:
            # Conv kernels: reorder axes (3, 2, 0, 1) — assumes TF HWIO layout. TODO confirm.
            converted = torch.from_numpy(tf_value).permute(3, 2, 0, 1)
        elif "depthwise_kernel" in tf_key:
            # Depthwise kernels use a different axis order, (2, 3, 0, 1).
            converted = torch.from_numpy(tf_value).permute(2, 3, 0, 1)
        elif "kernel" in tf_key:
            # Dense kernels: swap in/out feature dimensions.
            converted = torch.from_numpy(np.transpose(tf_value))
        else:
            # Biases, batch-norm stats, etc. carry over unchanged.
            converted = torch.from_numpy(tf_value)

        # Overwrite the HF parameter with the converted original weight.
        assert hf_params[hf_key].shape == converted.shape
        hf_params[hf_key].copy_(converted)
+
+
@torch.no_grad()
def convert_efficientnet_checkpoint(model_name, pytorch_dump_folder_path, save_model, push_to_hub):
    """
    Copy/paste/tweak model's weights to our EfficientNet structure.

    Loads the pretrained Keras model for `model_name` ("b0".."b7"), copies its
    weights into an `EfficientNetForImageClassification`, verifies that both
    models produce matching logits on a test image, and optionally saves the
    converted model to `pytorch_dump_folder_path` and/or pushes it to the hub.
    """
    # Load original model
    original_model = model_classes[model_name](
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000,
        classifier_activation="softmax",
    )

    # Merge trainable and non-trainable (e.g. batch-norm moving statistics)
    # variables into a single name -> ndarray dict.
    tf_params = original_model.trainable_variables
    tf_non_train_params = original_model.non_trainable_variables
    tf_params = {param.name: param.numpy() for param in tf_params}
    for param in tf_non_train_params:
        tf_params[param.name] = param.numpy()
    tf_param_names = list(tf_params.keys())

    # Load HuggingFace model
    config = get_efficientnet_config(model_name)
    hf_model = EfficientNetForImageClassification(config).eval()
    hf_params = hf_model.state_dict()

    # Create src-to-dst parameter name mapping dictionary
    print("Converting parameters...")
    key_mapping = rename_keys(tf_param_names)
    replace_params(hf_params, tf_params, key_mapping)

    # Initialize preprocessor and preprocess input image
    preprocessor = convert_image_processor(model_name)
    inputs = preprocessor(images=prepare_img(), return_tensors="pt")

    # HF model inference
    hf_model.eval()
    with torch.no_grad():
        outputs = hf_model(**inputs)
    hf_logits = outputs.logits.detach().numpy()

    # Original model inference. NEAREST resampling matches the preprocessor's
    # default so both pipelines see the same pixels.
    original_model.trainable = False
    image_size = CONFIG_MAP[model_name]["image_size"]
    img = prepare_img().resize((image_size, image_size), resample=PIL.Image.NEAREST)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    original_logits = original_model.predict(x)

    # Check whether original and HF model outputs match -> np.allclose
    assert np.allclose(original_logits, hf_logits, atol=1e-3), "The predicted logits are not the same."
    print("Model outputs match!")

    if save_model:
        # Create folder to save model
        if not os.path.isdir(pytorch_dump_folder_path):
            os.mkdir(pytorch_dump_folder_path)
        # Save converted model and image processor
        hf_model.save_pretrained(pytorch_dump_folder_path)
        preprocessor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Push model and image processor to hub
        print(f"Pushing converted {model_name} to the hub...")
        model_name = f"efficientnet-{model_name}"
        preprocessor.push_to_hub(model_name)
        hf_model.push_to_hub(model_name)
+
+
if __name__ == "__main__":
    # CLI entry point: parse the variant name and output options, then convert.
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="b0",
        type=str,
        help="Version name of the EfficientNet model you want to convert, select from [b0, b1, b2, b3, b4, b5, b6, b7].",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="hf_model",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub")

    args = parser.parse_args()
    convert_efficientnet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..612ede7086ead59ff28a051faf957a9120fdfa61
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py
@@ -0,0 +1,369 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for EfficientNet."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import rescale, resize, to_channel_dimension_format
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
+
+
# PIL is an optional dependency; it is only referenced for the default
# `resample` value and for type hints in this module.
if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)
+
+
class EfficientNetImageProcessor(BaseImageProcessor):
    r"""
    Constructs an EfficientNet image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in `preprocess`.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`):
            Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling` filter, *optional*, defaults to `PIL.Image.NEAREST`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `False`):
            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
            is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 289, "width": 289}`):
            Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        rescale_offset (`bool`, *optional*, defaults to `False`):
            Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. Can be
            overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        include_top (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image again. Should be set to True if the inputs are used for image classification.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Optional[Dict[str, int]] = None,
        resample: PILImageResampling = PIL.Image.NEAREST,
        do_center_crop: bool = False,
        crop_size: Optional[Dict[str, int]] = None,
        rescale_factor: Union[int, float] = 1 / 255,
        rescale_offset: bool = False,
        do_rescale: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        include_top: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Normalize the size dicts so downstream code can rely on height/width keys.
        size = size if size is not None else {"height": 346, "width": 346}
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else {"height": 289, "width": 289}
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.rescale_offset = rescale_offset
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.include_top = include_top
+
    # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST
    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.NEAREST,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])`.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.NEAREST`):
                `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.NEAREST`.
            data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the output image. If unset, the channel dimension format of the input
                image is used. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        Returns:
            `np.ndarray`: The resized image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
        output_size = (size["height"], size["width"])
        # `resize` below is the module-level functional transform from
        # image_transforms, not a recursive call to this method.
        return resize(
            image,
            size=output_size,
            resample=resample,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
+
+ def rescale(
+ self,
+ image: np.ndarray,
+ scale: Union[int, float],
+ offset: bool = True,
+ data_format: Optional[Union[str, ChannelDimension]] = None,
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
+ **kwargs,
+ ):
+ """
+ Rescale an image by a scale factor.
+
+ If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
+ 1/127.5, the image is rescaled between [-1, 1].
+ image = image * scale - 1
+
+ If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
+ image = image * scale
+
+ Args:
+ image (`np.ndarray`):
+ Image to rescale.
+ scale (`int` or `float`):
+ Scale to apply to the image.
+ offset (`bool`, *optional*):
+ Whether to scale the image in both negative and positive directions.
+ data_format (`str` or `ChannelDimension`, *optional*):
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
+ input_data_format (`ChannelDimension` or `str`, *optional*):
+ The channel dimension format of the input image. If not provided, it will be inferred.
+ """
+ rescaled_image = rescale(
+ image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
+ )
+
+ if offset:
+ rescaled_image = rescaled_image - 1
+
+ return rescaled_image
+
    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Optional[Dict[str, int]] = None,
        resample: Optional[PILImageResampling] = None,
        do_center_crop: Optional[bool] = None,
        crop_size: Optional[Dict[str, int]] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        rescale_offset: Optional[bool] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        include_top: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                PILImageResampling filter to use if resizing the image Only has an effect if `do_resize` is set to
                `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be
                padded with zeros and then cropped
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`):
                Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range].
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            include_top (`bool`, *optional*, defaults to `self.include_top`):
                Rescales the image again for image classification if set to True.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - `None`: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        # Fall back to the instance-level defaults for any argument not given.
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        rescale_offset = rescale_offset if rescale_offset is not None else self.rescale_offset
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        include_top = include_top if include_top is not None else self.include_top

        size = size if size is not None else self.size
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )
        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_resize:
            images = [
                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
                for image in images
            ]

        if do_center_crop:
            images = [
                self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
            ]

        if do_rescale:
            images = [
                self.rescale(
                    image=image, scale=rescale_factor, offset=rescale_offset, input_data_format=input_data_format
                )
                for image in images
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        if include_top:
            # Normalizing with mean 0 divides by `image_std` a second time — the
            # extra rescale used when the output feeds an image classifier
            # (see the `include_top` docs on this class).
            images = [
                self.normalize(image=image, mean=0, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
+
+
# Public API of this module.
__all__ = ["EfficientNetImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb639564014fa7eaa2c3114bce6a23e7e76ecf46
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py
@@ -0,0 +1,226 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for EfficientNet."""
+
+from functools import lru_cache
+from typing import Optional, Union
+
+from ...image_processing_utils_fast import (
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ BaseImageProcessorFast,
+ BatchFeature,
+ DefaultFastImageProcessorKwargs,
+)
+from ...image_transforms import group_images_by_shape, reorder_images
+from ...image_utils import (
+ IMAGENET_STANDARD_MEAN,
+ IMAGENET_STANDARD_STD,
+ ImageInput,
+ PILImageResampling,
+ SizeDict,
+)
+from ...processing_utils import Unpack
+from ...utils import (
+ TensorType,
+ add_start_docstrings,
+ is_torch_available,
+ is_torchvision_available,
+ is_torchvision_v2_available,
+)
+
+
+if is_torch_available():
+ import torch
+
+if is_torchvision_available():
+ if is_torchvision_v2_available():
+ from torchvision.transforms.v2 import functional as F
+ else:
+ from torchvision.transforms import functional as F
+
+
class EfficientNetFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    Extra `preprocess` kwargs accepted by the fast EfficientNet processor, on top
    of the standard fast-processor ones:

    rescale_offset (`bool`):
        Whether to rescale into [-scale_range, scale_range] instead of [0, scale_range].
    include_top (`bool`):
        Whether to normalize a second time (mean 0) for image-classification inputs.
    """

    rescale_offset: bool
    include_top: bool
+
+
@add_start_docstrings(
    "Constructs a fast EfficientNet image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class EfficientNetImageProcessorFast(BaseImageProcessorFast):
    # Class-level defaults mirroring the slow `EfficientNetImageProcessor`:
    # 346x346 NEAREST resize, optional 289x289 center crop, 1/255 rescale,
    # ImageNet standard mean/std, plus the EfficientNet-specific
    # `rescale_offset` and `include_top` switches.
    resample = PILImageResampling.NEAREST
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 346, "width": 346}
    crop_size = {"height": 289, "width": 289}
    do_resize = True
    do_center_crop = False
    do_rescale = True
    rescale_factor = 1 / 255
    rescale_offset = False
    do_normalize = True
    include_top = True
    valid_kwargs = EfficientNetFastImageProcessorKwargs

    def __init__(self, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]):
        """Accept the standard fast-processor kwargs plus `rescale_offset`/`include_top`."""
        super().__init__(**kwargs)
+
+ def rescale(
+ self,
+ image: "torch.Tensor",
+ scale: float,
+ offset: Optional[bool] = True,
+ **kwargs,
+ ) -> "torch.Tensor":
+ """
+ Rescale an image by a scale factor.
+
+ If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
+ 1/127.5, the image is rescaled between [-1, 1].
+ image = image * scale - 1
+
+ If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
+ image = image * scale
+
+ Args:
+ image (`torch.Tensor`):
+ Image to rescale.
+ scale (`float`):
+ The scaling factor to rescale pixel values by.
+ offset (`bool`, *optional*):
+ Whether to scale the image in both negative and positive directions.
+
+ Returns:
+ `torch.Tensor`: The rescaled image.
+ """
+
+ rescaled_image = image * scale
+
+ if offset:
+ rescaled_image -= 1
+
+ return rescaled_image
+
+ @lru_cache(maxsize=10)
+ def _fuse_mean_std_and_rescale_factor(
+ self,
+ do_normalize: Optional[bool] = None,
+ image_mean: Optional[Union[float, list[float]]] = None,
+ image_std: Optional[Union[float, list[float]]] = None,
+ do_rescale: Optional[bool] = None,
+ rescale_factor: Optional[float] = None,
+ device: Optional["torch.device"] = None,
+ rescale_offset: Optional[bool] = False,
+ ) -> tuple:
+ if do_rescale and do_normalize and not rescale_offset:
+ # Fused rescale and normalize
+ image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor)
+ image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor)
+ do_rescale = False
+ return image_mean, image_std, do_rescale
+
+ def rescale_and_normalize(
+ self,
+ images: "torch.Tensor",
+ do_rescale: bool,
+ rescale_factor: float,
+ do_normalize: bool,
+ image_mean: Union[float, list[float]],
+ image_std: Union[float, list[float]],
+ rescale_offset: bool = False,
+ ) -> "torch.Tensor":
+ """
+ Rescale and normalize images.
+ """
+ image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor(
+ do_normalize=do_normalize,
+ image_mean=image_mean,
+ image_std=image_std,
+ do_rescale=do_rescale,
+ rescale_factor=rescale_factor,
+ device=images.device,
+ rescale_offset=rescale_offset,
+ )
+ # if/elif as we use fused rescale and normalize if both are set to True
+ if do_rescale:
+ images = self.rescale(images, rescale_factor, rescale_offset)
+ if do_normalize:
+ images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
+
+ return images
+
+ def _preprocess(
+ self,
+ images: list["torch.Tensor"],
+ do_resize: bool,
+ size: SizeDict,
+ interpolation: Optional["F.InterpolationMode"],
+ do_center_crop: bool,
+ crop_size: SizeDict,
+ do_rescale: bool,
+ rescale_factor: float,
+ rescale_offset: bool,
+ do_normalize: bool,
+ include_top: bool,
+ image_mean: Optional[Union[float, list[float]]],
+ image_std: Optional[Union[float, list[float]]],
+ return_tensors: Optional[Union[str, TensorType]],
+ **kwargs,
+ ) -> BatchFeature:
+ # Group images by size for batched resizing
+ grouped_images, grouped_images_index = group_images_by_shape(images)
+ resized_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_resize:
+ stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation)
+ resized_images_grouped[shape] = stacked_images
+ resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+ # Group images by size for further processing
+ # Needed in case do_resize is False, or resize returns images with different sizes
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+ processed_images_grouped = {}
+ for shape, stacked_images in grouped_images.items():
+ if do_center_crop:
+ stacked_images = self.center_crop(stacked_images, crop_size)
+ # Fused rescale and normalize
+ stacked_images = self.rescale_and_normalize(
+ stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std, rescale_offset
+ )
+ if include_top:
+ stacked_images = self.normalize(stacked_images, 0, image_std)
+ processed_images_grouped[shape] = stacked_images
+
+ processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+ processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+ return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+ @add_start_docstrings(
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
+ """
+ rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`):
+ Whether to rescale the image between [-max_range/2, scale_range/2] instead of [0, scale_range].
+ include_top (`bool`, *optional*, defaults to `self.include_top`):
+ Normalize the image again with the standard deviation only for image classification if set to True.
+ """,
+ )
+ def preprocess(self, images: ImageInput, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]) -> BatchFeature:
+ return super().preprocess(images, **kwargs)
+
+
+__all__ = ["EfficientNetImageProcessorFast"]
diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0b89072921b47bd42a5febf9cdd576a31e33e1
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py
@@ -0,0 +1,647 @@
+# coding=utf-8
+# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch EfficientNet model."""
+
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import (
+ BaseModelOutputWithNoAttention,
+ BaseModelOutputWithPoolingAndNoAttention,
+ ImageClassifierOutputWithNoAttention,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+)
+from .configuration_efficientnet import EfficientNetConfig
+
+
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "EfficientNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/efficientnet-b7"
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/efficientnet-b7"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


# Docstring fragments spliced into the model classes by the `add_start_docstrings` decorators below.
EFFICIENTNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`EfficientNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

EFFICIENTNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
def round_filters(config: EfficientNetConfig, num_channels: int):
    r"""
    Round number of filters based on depth multiplier.

    Scales `num_channels` by `config.width_coefficient` and snaps the result to the nearest
    multiple of `config.depth_divisor`, never rounding down by more than 10%.
    """
    divisor = config.depth_divisor
    scaled = num_channels * config.width_coefficient

    # Snap to the nearest multiple of `divisor` (add half a divisor, then floor to a multiple),
    # with `divisor` as the floor.
    rounded = int(scaled + divisor / 2) // divisor * divisor
    rounded = max(divisor, rounded)

    # Make sure that round down does not go down by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += divisor

    return int(rounded)
+
+
def correct_pad(kernel_size: Union[int, Tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts padding value to apply to right and bottom sides of the input.

    Returns:
        `tuple`: Padding in `nn.ZeroPad2d` order — (left, right, top, bottom).
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    pad_h = kernel_size[0] // 2
    pad_w = kernel_size[1] // 2

    if adjust:
        # Shave one pixel off the left/top so strided convolutions line up with the
        # asymmetric padding used by the original (TF-style) implementation.
        return (pad_w - 1, pad_w, pad_h - 1, pad_h)
    return (pad_w, pad_w, pad_h, pad_h)
+
+
class EfficientNetEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.

    Pads, then applies a stride-2 3x3 convolution followed by batch norm and the configured
    activation, producing the feature map consumed by the first block.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()

        self.out_dim = round_filters(config, 32)  # stem width, scaled by width_coefficient
        # Asymmetric (right/bottom only) zero padding — presumably mirrors the original
        # TF-style "same" padding for the stride-2 stem conv; confirm against the reference.
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # pad -> conv -> batch norm -> activation
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features
+
+
class EfficientNetDepthwiseConv2d(nn.Conv2d):
    """
    Depthwise convolution: `groups=in_channels` gives each input channel its own set of
    `depth_multiplier` filters, producing `in_channels * depth_multiplier` output channels
    with no cross-channel mixing.
    """

    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,  # one group per channel = depthwise
            bias=bias,
            padding_mode=padding_mode,
        )
+
+
class EfficientNetExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.

    A 1x1 convolution that widens the channel dimension from `in_dim` to `out_dim`
    before the depthwise convolution.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
        # NOTE(review): `stride` is accepted but unused — the 1x1 expansion conv is always
        # stride 1; striding happens in the depthwise layer. Kept for signature parity.
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase: 1x1 conv -> batch norm -> activation
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states
+
+
class EfficientNetDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        # Stride-2 convs use explicit asymmetric zero padding + a "valid" conv; stride-1
        # convs rely on the built-in "same" padding instead.
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = EfficientNetDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution
        # The explicit pad layer is constructed unconditionally but only applied for stride 2.
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states
+
+
class EfficientNetSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the Squeeze and Excitement phase of each block in the original implementation.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        # Channel count seen by this layer: the expanded width when the block expands,
        # otherwise the block's input width.
        self.dim = expand_dim if expand else in_dim
        # SE bottleneck width, derived from the *unexpanded* input width; floored at 1.
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)  # global average pool -> (B, C, 1, 1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()  # per-channel gates in (0, 1)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        # Squeeze: global pooling, then bottleneck down to `dim_se` channels.
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        # Excite: project back up to `dim` channels and gate the original feature map.
        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states
+
+
class EfficientNetFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.

    A 1x1 projection back to `out_dim` (no activation — linear bottleneck), optionally
    followed by dropout and a residual connection with the block input.
    """

    def __init__(
        self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        # Dropout + residual only apply to stride-1 blocks that are not the first block of a
        # stage (`id_skip` is True for the first block, where input/output shapes may differ).
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            # Dropout, then the residual connection with the block's input (`embeddings`).
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states
+
+
class EfficientNetBlock(nn.Module):
    r"""
    This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
            of each block. Set to `True` for the first block of each stage.
        adjust_padding (`bool`):
            Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
            operation, set to `True` for inputs with odd input sizes.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        # Only insert the 1x1 expansion layer when the block actually widens its input.
        self.expand = True if self.expand_ratio != 1 else False
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = EfficientNetExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        # Downstream layers operate on the expanded width when expansion is present.
        self.depthwise_conv = EfficientNetDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = EfficientNetSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = EfficientNetFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        embeddings = hidden_states  # kept for the residual connection in the projection phase
        # Expansion and depthwise convolution phase
        if self.expand_ratio != 1:
            hidden_states = self.expansion(hidden_states)
        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)
        return hidden_states
+
+
class EfficientNetEncoder(nn.Module):
    r"""
    Forward propogates the embeddings through each EfficientNet block.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()
        self.config = config
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round number of block repeats based on depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))

        num_base_blocks = len(config.in_channels)
        # Total block count after depth scaling; used to schedule `drop_rate` below.
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                # Only the first repeat of a stage (j == 0) may stride/change width; later
                # repeats are stride-1, keep `out_dim`, and gain a residual connection.
                id_skip = True if j == 0 else False
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = False if curr_block_num in config.depthwise_padding else True
                # Drop-connect rate grows linearly with block index.
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = EfficientNetBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)
        # `out_dim` intentionally carries over from the last loop iteration: the top conv
        # consumes the output width of the final stage.
        self.top_conv = nn.Conv2d(
            in_channels=out_dim,
            out_channels=round_filters(config, 1280),
            kernel_size=1,
            padding="same",
            bias=False,
        )
        # NOTE(review): `top_bn` is sized with `config.hidden_dim` while `top_conv` outputs
        # `round_filters(config, 1280)` — these agree only when the config's `hidden_dim` is
        # set consistently with `width_coefficient`; confirm against the config class.
        self.top_bn = nn.BatchNorm2d(
            num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.top_activation = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithNoAttention:
        # Collect per-block feature maps (including the input embeddings) only when requested.
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        # Final 1x1 "top" conv -> BN -> activation produces the last hidden state.
        hidden_states = self.top_conv(hidden_states)
        hidden_states = self.top_bn(hidden_states)
        hidden_states = self.top_activation(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
+
+
class EfficientNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientNetConfig
    base_model_prefix = "efficientnet"
    main_input_name = "pixel_values"
    _no_split_modules = []

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.LayerNorm):
            # Identity-initialized layer norm: unit scale, zero shift.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
+
+
@add_start_docstrings(
    "The bare EfficientNet model outputting raw features without any specific head on top.",
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetModel(EfficientNetPreTrainedModel):
    def __init__(self, config: EfficientNetConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = EfficientNetEmbeddings(config)
        self.encoder = EfficientNetEncoder(config)

        # Final pooling layer
        # NOTE(review): kernel size = config.hidden_dim with ceil_mode=True — presumably
        # intended as a global spatial pool (feature maps are far smaller than the kernel);
        # confirm against the reference implementation.
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            # Fix: this message previously interpolated the non-existent `config.pooling`
            # attribute, so an invalid value raised AttributeError instead of this ValueError.
            raise ValueError(f"config.pooling_type must be one of ['mean', 'max'] got {config.pooling_type}")

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # Resolve per-call flags against the config defaults.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # stem -> blocks (+ top conv) -> pooled feature vector
        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Apply pooling
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
+
+
@add_start_docstrings(
    """
    EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
    for ImageNet.
    """,
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.efficientnet = EfficientNetModel(config)
        # Classifier head
        # Falls back to Identity when num_labels == 0 (pure feature extraction).
        self.dropout = nn.Dropout(p=config.dropout_rate)
        self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Infer the problem type once (cached on the config) from num_labels and the
            # label dtype; subsequent calls reuse it.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Tuple layout: (loss?, logits, *hidden_states) — `outputs[2:]` skips the
            # backbone's last_hidden_state and pooled output.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


__all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/electra/__init__.py b/docs/transformers/build/lib/transformers/models/electra/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78ed5c42aea51038335efabde5b03e333592ed6
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers / IDEs: expose the full public API eagerly.
    from .configuration_electra import *
    from .modeling_electra import *
    from .modeling_flax_electra import *
    from .modeling_tf_electra import *
    from .tokenization_electra import *
    from .tokenization_electra_fast import *
else:
    import sys

    # At runtime, replace this module with a lazy loader so heavy backends (torch/tf/flax)
    # are only imported when one of their symbols is first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py b/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..20b242c0f8d65fd5a4a3109fdf3b58a3ff7b9181
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py
@@ -0,0 +1,187 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ELECTRA model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class ElectraConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ElectraModel`] or a [`TFElectraModel`]. It is
    used to instantiate a ELECTRA model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the ELECTRA
    [google/electra-small-discriminator](https://huggingface.co/google/electra-small-discriminator) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the ELECTRA model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`].
        embedding_size (`int`, *optional*, defaults to 128):
            Dimensionality of the token, position and token-type embeddings (may differ from `hidden_size`; the
            model projects embeddings up to `hidden_size` when they differ).
        hidden_size (`int`, *optional*, defaults to 256):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 4):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        summary_type (`str`, *optional*, defaults to `"first"`):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Has to be one of the following options:

            - `"last"`: Take the last token hidden state (like XLNet).
            - `"first"`: Take the first token hidden state (like BERT).
            - `"mean"`: Take the mean of all tokens hidden states.
            - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
            - `"attn"`: Not implemented now, use multi-head attention.
        summary_use_proj (`bool`, *optional*, defaults to `True`):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Whether or not to add a projection after the vector extraction.
        summary_activation (`str`, *optional*, defaults to `"gelu"`):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            Pass `"gelu"` for a gelu activation to the output, any other value will result in no activation.
        summary_last_dropout (`float`, *optional*, defaults to 0.1):
            Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.

            The dropout ratio to be used after the projection and activation.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the token used for padding; forwarded to [`PretrainedConfig`].
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import ElectraConfig, ElectraModel

    >>> # Initializing a ELECTRA electra-base-uncased style configuration
    >>> configuration = ElectraConfig()

    >>> # Initializing a model (with random weights) from the electra-base-uncased style configuration
    >>> model = ElectraModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "electra"

    def __init__(
        self,
        vocab_size=30522,
        embedding_size=128,
        hidden_size=256,
        num_hidden_layers=12,
        num_attention_heads=4,
        intermediate_size=1024,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        summary_type="first",
        summary_use_proj=True,
        summary_activation="gelu",
        summary_last_dropout=0.1,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_last_dropout = summary_last_dropout
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
+
+
class ElectraOnnxConfig(OnnxConfig):
    """ONNX export configuration for ELECTRA: declares model inputs and their dynamic axes."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered mapping of input names to their dynamic ONNX axis labels."""
        # Multiple-choice tasks carry an extra "choice" dimension between batch and sequence.
        if self.task == "multiple-choice":
            axes = {0: "batch", 1: "choice", 2: "sequence"}
        else:
            axes = {0: "batch", 1: "sequence"}
        return OrderedDict((name, axes) for name in ("input_ids", "attention_mask", "token_type_ids"))


__all__ = ["ElectraConfig", "ElectraOnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0abc30cd758743b243baabbf1298bcc2e1e595e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ELECTRA checkpoint."""
+
+import argparse
+
+import torch
+
+from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+
+
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator):
    """Convert a TensorFlow ELECTRA checkpoint to a PyTorch state dict and save it.

    Args:
        tf_checkpoint_path: Path to the TensorFlow checkpoint to convert.
        config_file: Path to the JSON config describing the model architecture.
        pytorch_dump_path: Destination path for the converted PyTorch weights.
        discriminator_or_generator: Which ELECTRA head to export; must be
            `"discriminator"` or `"generator"`.

    Raises:
        ValueError: If `discriminator_or_generator` is not one of the two valid values.
    """
    # Build an untrained PyTorch model matching the checkpoint's architecture.
    config = ElectraConfig.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")

    # Pick the model class for the requested head.
    model_classes = {"discriminator": ElectraForPreTraining, "generator": ElectraForMaskedLM}
    if discriminator_or_generator not in model_classes:
        raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'")
    model = model_classes[discriminator_or_generator](config)

    # Copy the TF variables into the PyTorch parameters.
    load_tf_weights_in_electra(
        model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator
    )

    # Persist only the weights (state dict), not the full module.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    torch.save(model.state_dict(), pytorch_dump_path)
+
+
# CLI entry point: parse the four required arguments and run the conversion.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--discriminator_or_generator",
        default=None,
        type=str,
        required=True,
        help=(
            "Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or "
            "'generator'."
        ),
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(
        args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator
    )
diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..08cc3e530d6f4a1753dd1e178ec178993e6c951d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py
@@ -0,0 +1,1777 @@
+# coding=utf-8
+# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ELECTRA model."""
+
+import math
+import os
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, get_activation
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+ BaseModelOutputWithCrossAttentions,
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_electra import ElectraConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
+_CONFIG_FOR_DOC = "ElectraConfig"
+
+
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
    """Load tf checkpoints in a pytorch model.

    Walks every variable in the TF checkpoint, remaps its scope path to the
    corresponding PyTorch attribute path, and copies the weights in place.
    Variables without a PyTorch counterpart are skipped (with a message).

    Args:
        model: The target PyTorch ELECTRA model (modified in place).
        config: Model configuration (unused here but kept for API symmetry with
            the other `load_tf_weights_in_*` converters).
        tf_checkpoint_path: Path to the TF checkpoint directory/prefix.
        discriminator_or_generator: Which head the checkpoint names refer to;
            controls the scope-name remapping below.

    Returns:
        The same `model`, with weights loaded.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)
    for name, array in zip(names, arrays):
        original_name: str = name

        try:
            # Remap TF scope names to the PyTorch module tree. NOTE(review): the
            # generator checkpoint apparently shares the discriminator's embedding
            # scope, hence the swaps below — confirm against the original TF repo.
            if isinstance(model, ElectraForMaskedLM):
                name = name.replace("electra/embeddings/", "generator/embeddings/")

            if discriminator_or_generator == "generator":
                name = name.replace("electra/", "discriminator/")
                name = name.replace("generator/", "electra/")

            name = name.replace("dense_1", "dense_prediction")
            name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias")

            name = name.split("/")
            # print(original_name, name)
            # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
            # which are not required for using pretrained model
            if any(n in ["global_step", "temperature"] for n in name):
                logger.info(f"Skipping {original_name}")
                continue
            # Walk the attribute path segment by segment, e.g. "layer_3" -> model.layer[3].
            pointer = model
            for m_name in name:
                if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                    scope_names = re.split(r"_(\d+)", m_name)
                else:
                    scope_names = [m_name]
                if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                    pointer = getattr(pointer, "weight")
                elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                    pointer = getattr(pointer, "bias")
                elif scope_names[0] == "output_weights":
                    pointer = getattr(pointer, "weight")
                elif scope_names[0] == "squad":
                    pointer = getattr(pointer, "classifier")
                else:
                    pointer = getattr(pointer, scope_names[0])
                if len(scope_names) >= 2:
                    num = int(scope_names[1])
                    pointer = pointer[num]
            # Embedding tables map to the module's .weight; TF "kernel" matrices are
            # stored transposed relative to torch.nn.Linear, so transpose on load.
            if m_name.endswith("_embeddings"):
                pointer = getattr(pointer, "weight")
            elif m_name == "kernel":
                array = np.transpose(array)
            try:
                if pointer.shape != array.shape:
                    raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
            except ValueError as e:
                e.args += (pointer.shape, array.shape)
                raise
            print(f"Initialize PyTorch weight {name}", original_name)
            pointer.data = torch.from_numpy(array)
        except AttributeError as e:
            # Variable has no PyTorch counterpart (e.g. optimizer slots) — best-effort skip.
            print(f"Skipping {original_name}", name, e)
            continue
    return model
+
+
class ElectraEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed token ids (or use precomputed embeddings), add token-type and
        absolute position embeddings, then apply LayerNorm and dropout."""
        shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = shape[1]

        # When a cache is in use, positions continue where the cached sequence ended.
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Default token_type_ids to the all-zeros registered buffer so the model can
        # be traced without the caller passing them explicitly.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                token_type_ids = self.token_type_ids[:, :seq_length].expand(shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings = embeddings + self.position_embeddings(position_ids)
        return self.dropout(self.LayerNorm(embeddings))
+
+
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
class ElectraSelfAttention(nn.Module):
    """Multi-head self-attention (eager implementation), supporting cross-attention
    and decoder key/value caching, with optional relative position embeddings."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # hidden_size must split evenly across heads, unless the model uses a separate
        # embedding_size (in which case the check is deferred).
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per possible (query - key) offset, in [-(L-1), L-1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, all_head_size) to (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute attention. Returns (context,) plus, optionally, attention probs
        and (for decoders) the updated key/value cache as the last element."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Decoder self-attention with cache: append new keys/values to the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Add learned relative-position scores (Shaw et al. 2018 style).
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                # With a cache only one query position is processed: the last one.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift distances into [0, 2L-2] to index the embedding table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
# Adapted from transformers.models.bert.modeling_bert.BertSelfOutput
class ElectraSelfOutput(nn.Module):
    """Project attention output back to hidden size, apply dropout, then add the
    residual connection and normalize."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm (post-LN transformer style).
        return self.LayerNorm(projected + input_tensor)
+
+
# Maps the attention implementation name (looked up via config._attn_implementation)
# to the class providing it; only the eager (pure-PyTorch) path is defined here.
ELECTRA_SELF_ATTENTION_CLASSES = {
    "eager": ElectraSelfAttention,
}
+
+
# Adapted from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA
class ElectraAttention(nn.Module):
    """Full attention block: the self-attention module plus its residual output projection."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Select the attention implementation configured on the model.
        self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = ElectraSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads, shrinking the q/k/v and output projections."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Drop the pruned heads' rows (q/k/v) and columns (output) from the projections.
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep head bookkeeping in sync with the shrunken projections.
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run self-attention, then the residual output projection.

        Returns the projected context first, followed by whatever extras the
        inner attention produced (attention probs and/or key-value cache).
        """
        attn_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
        )
        attention_output = self.output(attn_outputs[0], hidden_states)
        return (attention_output,) + attn_outputs[1:]
+
+
# Adapted from transformers.models.bert.modeling_bert.BertIntermediate
class ElectraIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act may be an activation name (resolved via ACT2FN) or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
+
+
# Adapted from transformers.models.bert.modeling_bert.BertOutput
class ElectraOutput(nn.Module):
    """Feed-forward contraction: intermediate_size -> hidden_size, dropout, then
    residual connection and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
+
+
# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra
class ElectraLayer(nn.Module):
    """One transformer encoder/decoder layer: self-attention, optional
    cross-attention (decoder only), and the chunked feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ElectraAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            # Cross-attention always uses absolute positions regardless of the config.
            self.crossattention = ElectraAttention(config, position_embedding_type="absolute")
        self.intermediate = ElectraIntermediate(config)
        self.output = ElectraOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run the layer. Returns (layer_output, [attention probs...], [present_key_value]);
        the cache tuple is appended only when the layer is a decoder."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Feed-forward is applied in chunks along the sequence dim to cap peak memory.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate + output feed-forward to one chunk of positions."""
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
+
+
# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra
class ElectraEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `ElectraLayer` blocks run in sequence."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally (see `supports_gradient_checkpointing` on the pretrained model).
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run all layers, optionally accumulating hidden states, attention maps,
        and the decoder key/value cache."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        # Checkpointing recomputes activations in backward, which is incompatible
        # with returning a KV cache, so caching is force-disabled here.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Hidden states are recorded *before* each layer (plus once after the loop),
                # giving embeddings + one entry per layer.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # When decoding, each layer appends its present key/value as the last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Legacy tuple output: drop entries that were not requested (None).
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
class ElectraDiscriminatorPredictions(nn.Module):
    """Discriminator head: a hidden-size projection followed by a per-token binary logit."""

    def __init__(self, config):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = get_activation(config.hidden_act)
        self.dense_prediction = nn.Linear(config.hidden_size, 1)
        self.config = config

    def forward(self, discriminator_hidden_states):
        # Project and activate, then collapse the trailing singleton dimension so
        # the returned logits have shape (batch, seq_len).
        projected = self.activation(self.dense(discriminator_hidden_states))
        return self.dense_prediction(projected).squeeze(-1)
+
+
class ElectraGeneratorPredictions(nn.Module):
    """Generator head: project hidden states down to embedding size, then GELU + LayerNorm."""

    def __init__(self, config):
        super().__init__()

        self.activation = get_activation("gelu")
        self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        self.dense = nn.Linear(config.hidden_size, config.embedding_size)

    def forward(self, generator_hidden_states):
        # dense -> gelu -> layer norm, yielding embedding-sized states for the LM head.
        projected = self.dense(generator_hidden_states)
        return self.LayerNorm(self.activation(projected))
+
+
class ElectraPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ElectraConfig
    load_tf_weights = load_tf_weights_in_electra
    base_model_prefix = "electra"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize one submodule's weights (invoked by `post_init` for every submodule)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Padding embedding stays zero so it contributes nothing.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
+
+
@dataclass
class ElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ElectraForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss of the ELECTRA objective.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    """

    # Fields default to None so that `ModelOutput`'s tuple/dict conversion skips
    # whichever outputs were not requested.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
# Shared class-level docstring prologue, injected into each model class via
# the `add_start_docstrings` decorator.
ELECTRA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-argument docstring template; `{0}` is substituted with the input shape
# (e.g. "batch_size, sequence_length") by `add_start_docstrings_to_model_forward`.
ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        encoder_hidden_states (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
    "hidden size and embedding size are different. "
    ""
    "Both the generator and discriminator checkpoints may be loaded into this model.",
    ELECTRA_START_DOCSTRING,
)
class ElectraModel(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = ElectraEmbeddings(config)

        # Unlike BERT, ELECTRA allows embedding_size != hidden_size; bridge the two
        # with a projection only when they differ.
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

        self.encoder = ElectraEncoder(config)
        self.config = config
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Token embedding table (used for weight tying / resizing).
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]:
        """Embed the inputs, optionally project to hidden size, and run the encoder.

        Returns the encoder output unchanged (tuple or ModelOutput per `return_dict`)."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length: length of the cached prefix (keys/values are
        # (batch, heads, seq, head_dim), so dim 2 is the cached sequence length).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            # Prefer the registered buffer (keeps device/dtype consistent for traced models).
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Expand the 2D mask to the additive 4D mask the attention layers expect.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        hidden_states = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # Only present when embedding_size != hidden_size (see __init__).
        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states)

        hidden_states = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return hidden_states
+
+
class ElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.hidden_dropout_prob if config.classifier_dropout is None else config.classifier_dropout
        )
        self.activation = get_activation("gelu")
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # Summarize the sequence by its first token (the [CLS] position),
        # then dropout -> dense -> gelu -> dropout -> label projection.
        cls_state = features[:, 0, :]
        cls_state = self.dense(self.dropout(cls_state))
        # ELECTRA uses gelu here where BERT's pooler uses tanh.
        cls_state = self.dropout(self.activation(cls_state))
        return self.out_proj(cls_state)
+
+
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra
class ElectraSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`ElectraConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config: ElectraConfig):
        super().__init__()

        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        # Each stage defaults to a no-op (Identity) and is only replaced when the
        # corresponding config attribute enables it.
        self.summary = nn.Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()

        self.first_dropout = nn.Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = nn.Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                # Default to the last position for every sequence in the batch.
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        # Pipeline: dropout -> (optional) projection -> (optional) activation -> dropout.
        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
+
+
@add_start_docstrings(
    """
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForSequenceClassification(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.electra = ElectraModel(config)
        self.classifier = ElectraClassificationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-emotion",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="'joy'",
        expected_loss=0.06,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = discriminator_hidden_states[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels / label dtype and cache it on
            # the config, so subsequent calls take the fast path below.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Legacy tuple output: (loss?, logits, *encoder extras).
            output = (logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForPreTraining(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.discriminator_predictions = ElectraDiscriminatorPredictions(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=ElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], ElectraForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
            Indices should be in `[0, 1]`:

            - 0 indicates the token is an original token,
            - 1 indicates the token was replaced.

        Returns:

        Examples:

        ```python
        >>> from transformers import ElectraForPreTraining, AutoTokenizer
        >>> import torch

        >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")

        >>> sentence = "The quick brown fox jumps over the lazy dog"
        >>> fake_sentence = "The quick brown fox fake over the lazy dog"

        >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
        >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
        >>> discriminator_outputs = discriminator(fake_inputs)
        >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)

        >>> fake_tokens
        ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']

        >>> predictions.squeeze().tolist()
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        # Per-token replaced/original logits, shape (batch, seq_len).
        logits = self.discriminator_predictions(discriminator_sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            if attention_mask is not None:
                # Restrict the BCE loss to real (non-padding) token positions.
                active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
                active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
                active_labels = labels[active_loss]
                loss = loss_fct(active_logits, active_labels.float())
            else:
                loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())

        if not return_dict:
            output = (logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return ElectraForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
    the two to have been trained for the masked language modeling task.
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForMaskedLM(ElectraPreTrainedModel):
    _tied_weights_keys = ["generator_lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)

        # Projects embedding-sized generator states onto the vocabulary; its weight
        # is tied to the input embeddings (see `_tied_weights_keys`).
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Exposed so the base class can tie it with the input embeddings.
        return self.generator_lm_head

    def set_output_embeddings(self, word_embeddings):
        self.generator_lm_head = word_embeddings

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="google/electra-small-generator",
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="[MASK]",
        expected_output="'paris'",
        expected_loss=1.22,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = generator_hidden_states[0]

        # Generator head (dense + gelu + layer norm), then the vocab projection.
        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))

        loss = None
        # Masked language modeling softmax layer
        if labels is not None:
            # CrossEntropyLoss ignores positions labeled -100 (padding / unmasked tokens).
            loss = nn.CrossEntropyLoss()(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + generator_hidden_states[1:]
            if loss is not None:
                return (loss,) + output
            return output

        return MaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states.hidden_states,
            attentions=generator_hidden_states.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForTokenClassification(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.electra = ElectraModel(config)
        # Prefer a classifier-specific dropout probability when configured,
        # otherwise fall back to the generic hidden dropout.
        if config.classifier_dropout is not None:
            dropout_prob = config.classifier_dropout
        else:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
        expected_loss=0.11,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token logits computed from the dropout-regularized last hidden state.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tuple_output = (logits,) + encoder_outputs[1:]
        return ((loss,) + tuple_output) if loss is not None else tuple_output
+
+
@add_start_docstrings(
    """
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForQuestionAnswering(ElectraPreTrainedModel):
    config_class = ElectraConfig
    base_model_prefix = "electra"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.electra = ElectraModel(config)
        # Projects each token's hidden state to two logits: span start and span end.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="bhadresh-savani/electra-base-squad2",
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
        qa_target_start_index=11,
        qa_target_end_index=12,
        expected_output="'a nice puppet'",
        expected_loss=2.64,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            # Fix: forward `return_dict` like every other task head in this file —
            # previously the backbone silently followed the config default, ignoring
            # the caller's `return_dict` argument for the inner call.
            return_dict=return_dict,
        )

        sequence_output = discriminator_hidden_states[0]

        # Split the 2-channel projection into separate start/end logit tensors.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + discriminator_hidden_states[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
+
+
@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class ElectraForMultipleChoice(ElectraPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.sequence_summary = ElectraSequenceSummary(config)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        # Collapse the choices dimension: (batch, choices, seq) -> (batch * choices, seq).
        def _flatten(tensor):
            return tensor.view(-1, tensor.size(-1)) if tensor is not None else None

        input_ids = _flatten(input_ids)
        attention_mask = _flatten(attention_mask)
        token_type_ids = _flatten(token_type_ids)
        position_ids = _flatten(position_ids)
        if inputs_embeds is not None:
            # Embeddings carry a trailing hidden dimension, so flatten to 3-D.
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        encoder_outputs = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pool each flattened sequence, score it, then restore the choices axis.
        pooled_output = self.sequence_summary(encoder_outputs[0])
        reshaped_logits = self.classifier(pooled_output).view(-1, num_choices)

        loss = None
        if labels is not None:
            loss = CrossEntropyLoss()(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        tuple_output = (reshaped_logits,) + encoder_outputs[1:]
        return ((loss,) + tuple_output) if loss is not None else tuple_output
+
+
@add_start_docstrings(
    """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
)
class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
    # The LM head weight is tied to the input embedding matrix.
    _tied_weights_keys = ["generator_lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`")

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)

        # Consistency fix: every other model class in this file calls `post_init()`
        # (weight init + final processing) instead of bare `init_weights()`.
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head linear layer (hook used for output-embedding tying)."""
        return self.generator_lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head projection layer with `new_embeddings`."""
        self.generator_lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
        >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
        >>> config.is_decoder = True
        >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # When training with labels, caching key/value states is pointless.
            use_cache = False

        outputs = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))

        lm_loss = None
        if labels is not None:
            # NOTE(review): `self.loss_function` comes from the base class and is
            # assumed to implement the shifted next-token CLM loss — confirm there.
            lm_loss = self.loss_function(
                prediction_scores,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache
    def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
+
+
# Public API of this module (consumed by the package-level lazy importer).
__all__ = [
    "ElectraForCausalLM",
    "ElectraForMaskedLM",
    "ElectraForMultipleChoice",
    "ElectraForPreTraining",
    "ElectraForQuestionAnswering",
    "ElectraForSequenceClassification",
    "ElectraForTokenClassification",
    "ElectraModel",
    "ElectraPreTrainedModel",
    "load_tf_weights_in_electra",
]
diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a445e6ccaa7ce3294655894be874bc83300a38
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py
@@ -0,0 +1,1614 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple
+
+import flax
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+
+from ...modeling_flax_outputs import (
+ FlaxBaseModelOutput,
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
+ FlaxCausalLMOutputWithCrossAttentions,
+ FlaxMaskedLMOutput,
+ FlaxMultipleChoiceModelOutput,
+ FlaxQuestionAnsweringModelOutput,
+ FlaxSequenceClassifierOutput,
+ FlaxTokenClassifierOutput,
+)
+from ...modeling_flax_utils import (
+ ACT2FN,
+ FlaxPreTrainedModel,
+ append_call_sample_docstring,
+ append_replace_return_docstrings,
+ overwrite_call_docstring,
+)
+from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from .configuration_electra import ElectraConfig
+
+
logger = logging.get_logger(__name__)  # module-level logger

# Default checkpoint / config names substituted into generated docstrings below.
_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"

# Shorthand for Flax's rematerialization (gradient-checkpointing) wrapper.
remat = nn_partitioning.remat
+
+
@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ElectraForPreTraining`].

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # All fields default to None before being populated, hence Optional.
    logits: Optional[jnp.ndarray] = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
+
+
# Generic usage notes prepended to every Flax ELECTRA model docstring via
# `add_start_docstrings`.
ELECTRA_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Argument documentation injected into each model's `__call__` docstring; `{0}`
# is filled with the concrete input shape by `.format(...)` at decoration time.
ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

"""
+
+
class FlaxElectraEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Word/position/segment lookup tables, all sized `embedding_size`
        # (which for ELECTRA may differ from `hidden_size`).
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        """Embed ids, sum the three embeddings, then apply LayerNorm and dropout.

        NOTE(review): `attention_mask` is accepted but unused in this body —
        presumably kept for signature parity with the copied BERT implementation.
        """
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states
+
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
class FlaxElectraSelfAttention(nn.Module):
    """Multi-head self-/cross-attention with optional causal masking and a fast-decoding KV cache."""

    config: ElectraConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # Fix: this message was a plain string containing literal `{...}`
            # placeholders; it must be an f-string for the values to appear.
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            # Static lower-triangular mask over the maximum sequence length.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # (..., seq, hidden) -> (..., seq, num_heads, head_dim)
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # (..., seq, num_heads, head_dim) -> (..., seq, hidden)
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # Fast decoding: slice the causal mask at the current cache offset.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
+
+
# Post-attention projection layer (originally mirrored from the Flax BERT self-output layer).
class FlaxElectraSelfOutput(nn.Module):
    """Projects attention output back to ``hidden_size``, then dropout and a residual LayerNorm."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # computation dtype

    def setup(self):
        # NOTE: attribute names (dense / LayerNorm / dropout) define the parameter-tree keys; keep them stable.
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=init_fn, dtype=self.dtype)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, deterministic=deterministic)
        return self.LayerNorm(projected + input_tensor)
+
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
class FlaxElectraAttention(nn.Module):
    """Full attention block: self/cross attention (``self``) followed by the residual output layer (``output``)."""

    config: ElectraConfig
    causal: bool = False  # True enables decoder-style (autoregressive) masking in the inner attention
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,  # when given, the inner attention attends over these states (cross-attention)
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        # Residual projection: combines the attention output with the block input (`hidden_states`).
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        # Tuple layout: (hidden_states,) or (hidden_states, attention_probs).
        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs
+
+
# Feed-forward expansion layer (originally mirrored from the Flax BERT intermediate layer).
class FlaxElectraIntermediate(nn.Module):
    """First half of the transformer MLP: Dense up-projection to ``intermediate_size`` plus activation."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # computation dtype

    def setup(self):
        # Attribute names define the parameter-tree keys; keep them stable.
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.intermediate_size, kernel_init=init_fn, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        """Apply the up-projection followed by the configured activation."""
        return self.activation(self.dense(hidden_states))
+
+
# Second half of the transformer MLP (originally mirrored from the Flax BERT output layer).
class FlaxElectraOutput(nn.Module):
    """Down-projects the intermediate activations, then dropout and a residual LayerNorm."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # computation dtype

    def setup(self):
        # Attribute names define the parameter-tree keys; keep them stable.
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=init_fn, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + attention_output)``."""
        projected = self.dense(hidden_states)
        projected = self.dropout(projected, deterministic=deterministic)
        return self.LayerNorm(projected + attention_output)
+
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra
class FlaxElectraLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention, then the MLP."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention is causal only when the model is configured as a decoder.
        self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block — only runs when encoder states are provided; note the mask used
        # here is the *encoder* attention mask, and the keys/values come from the encoder states.
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        # Feed-forward sub-block with its own residual (applied inside self.output).
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        # Tuple layout: (hidden_states[, self_attn_probs[, cross_attn_probs]]).
        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs
+
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra
class FlaxElectraLayerCollection(nn.Module):
    """Stack of ``num_hidden_layers`` transformer layers with optional gradient checkpointing."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            # static_argnums (5, 6, 7) = init_cache, deterministic, output_attentions in the layer's
            # positional call below — Python bools that must remain static under remat.
            FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            self.layers = [
                FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Accumulators are None when the corresponding output is not requested.
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f" {head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            # Hidden states are recorded *before* each layer; the final state is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Positional call: argument order must match the remat static_argnums above.
            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        # Tuple return drops the None entries so positions are not stable across flag combinations.
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra
class FlaxElectraEncoder(nn.Module):
    """Thin wrapper around the layer stack; exists to keep the ``layer/...`` parameter-tree prefix."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxElectraLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Pure delegation: all behavior lives in FlaxElectraLayerCollection.
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
class FlaxElectraGeneratorPredictions(nn.Module):
    """Generator (masked-LM) head body: Dense projection to ``embedding_size``, activation, LayerNorm."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Attribute names define the parameter-tree keys; keep them stable.
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)

    def __call__(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        activation = ACT2FN[self.config.hidden_act]
        return self.LayerNorm(activation(self.dense(hidden_states)))
+
+
class FlaxElectraDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.dense_prediction = nn.Dense(1, dtype=self.dtype)

    def __call__(self, hidden_states):
        """Score every token position; the trailing singleton logit dim is squeezed away."""
        features = ACT2FN[self.config.hidden_act](self.dense(hidden_states))
        return self.dense_prediction(features).squeeze(-1)
+
+
class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ElectraConfig
    base_model_prefix = "electra"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ElectraConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        # Build the concrete inner module declared by the subclass through `module_class`.
        module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
    def enable_gradient_checkpointing(self):
        # Swap in a rematerializing module; existing params stay valid (same parameter-tree layout).
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )

    # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """Initialize the parameter tree with dummy inputs; fills missing keys when `params` is given."""
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # Cross-attention weights only materialize if encoder states are fed at init time.
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is not None:
            # Complete a partially-loaded checkpoint: copy freshly initialized values for missing keys.
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # A throwaway init run whose only purpose is to create the "cache" variable collection.
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
    ):
        # Resolve output flags against the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            # NOTE(review): defaults to ones here while init_weights uses zeros — mirrors upstream
            # transformers; confirm the asymmetry is intended.
            token_type_ids = jnp.ones_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if self.config.add_cross_attention:
            # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
            # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
            # changed by FlaxElectraAttention module
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # add updated cache to model output
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                # Tuple path: splice the updated cache in right after the first element.
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        else:
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
            )

        return outputs
+
+
class FlaxElectraModule(nn.Module):
    """Bare Electra encoder: embeddings (+ optional up-projection) followed by the transformer stack."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
        # Electra may use a smaller embedding size than hidden size; project up when they differ.
        if self.config.embedding_size != self.config.hidden_size:
            self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.encoder = FlaxElectraEncoder(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        embeddings = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # `embeddings_project` only exists when embedding_size != hidden_size (see setup).
        if hasattr(self, "embeddings_project"):
            embeddings = self.embeddings_project(embeddings)

        return self.encoder(
            embeddings,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
+
+
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraModel(FlaxElectraPreTrainedModel):
    # Public head-less model: all logic lives in FlaxElectraModule / FlaxElectraPreTrainedModel.
    module_class = FlaxElectraModule


append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
+
+
class FlaxElectraTiedDense(nn.Module):
    """Dense layer whose kernel is supplied by the caller (the tied word-embedding matrix); only the bias is owned."""

    embedding_size: int
    dtype: jnp.dtype = jnp.float32
    precision = None  # NOTE(review): unannotated, so this is a plain class attribute rather than a module field
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.bias = self.param("bias", self.bias_init, (self.embedding_size,))

    def __call__(self, x, kernel):
        """Return ``x @ kernel + bias`` (contracts the last axis of `x` with the first axis of `kernel`)."""
        x = jnp.asarray(x, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)
        y = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        bias = jnp.asarray(self.bias, self.dtype)
        return y + bias
+
+
class FlaxElectraForMaskedLMModule(nn.Module):
    """Electra backbone with the generator (masked-LM) head on top."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        if self.config.tie_word_embeddings:
            # The tied head owns only a bias; the projection kernel is the shared embedding matrix.
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Weight tying: reuse the input embedding table (transposed) as the output projection.
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING)
class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
    # Public class: wiring only; the logic lives in FlaxElectraForMaskedLMModule.
    module_class = FlaxElectraForMaskedLMModule


append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
+
+
class FlaxElectraForPreTrainingModule(nn.Module):
    """Electra backbone plus the discriminator (replaced-token-detection) head."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode the tokens and score each position with the discriminator head."""
        backbone_outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.discriminator_predictions(backbone_outputs[0])

        if not return_dict:
            return (logits,) + backbone_outputs[1:]

        return FlaxElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
    # Public class: wiring only; the logic lives in FlaxElectraForPreTrainingModule.
    module_class = FlaxElectraForPreTrainingModule


FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
    >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.logits
    ```
"""

# Attach the usage example and return-type docs to the generated __call__ docstring.
overwrite_call_docstring(
    FlaxElectraForPreTraining,
    ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)
+
+
class FlaxElectraForTokenClassificationModule(nn.Module):
    """Electra backbone with a per-token classification head (dropout + Dense to ``num_labels``)."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
        if self.config.classifier_dropout is not None:
            classifier_dropout = self.config.classifier_dropout
        else:
            classifier_dropout = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode the tokens and classify every position."""
        backbone_outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(backbone_outputs[0], deterministic=deterministic)
        logits = self.classifier(sequence_output)

        if not return_dict:
            return (logits,) + backbone_outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
    # Public class: wiring only; the logic lives in FlaxElectraForTokenClassificationModule.
    module_class = FlaxElectraForTokenClassificationModule


append_call_sample_docstring(
    FlaxElectraForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
+
+
def identity(x, **unused_kwargs):
    """Return `x` unchanged, swallowing any keyword arguments (drop-in stand-in for layers like Dropout)."""
    return x
+
+
class FlaxElectraSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Every stage defaults to the pass-through `identity`; config flags selectively replace them.
        self.summary = identity
        if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
            if (
                hasattr(self.config, "summary_proj_to_labels")
                and self.config.summary_proj_to_labels
                and self.config.num_labels > 0
            ):
                num_classes = self.config.num_labels
            else:
                num_classes = self.config.hidden_size
            self.summary = nn.Dense(num_classes, dtype=self.dtype)

        activation_string = getattr(self.config, "summary_activation", None)
        self.activation = ACT2FN[activation_string] if activation_string else lambda x: x  # noqa F407

        self.first_dropout = identity
        if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

        self.last_dropout = identity
        if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(self.config.summary_last_dropout)

    def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `jnp.ndarray`: The summary of the sequence hidden states.
        """
        # NOTE: this always does the "first" type summary (takes the hidden state of the first token);
        # `cls_index` is accepted for interface parity but unused here.
        output = hidden_states[:, 0]
        output = self.first_dropout(output, deterministic=deterministic)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output, deterministic=deterministic)
        return output
+
+
class FlaxElectraForMultipleChoiceModule(nn.Module):
    """Electra backbone with a multiple-choice head: sequence summary + Dense(1), reshaped over choices."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Inputs arrive as (batch, num_choices, seq_len); flatten choices into the batch dimension
        # so each choice is encoded independently.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        # Restore the (batch, num_choices) layout for softmax over choices.
        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[1:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
    # Public class: wiring only; the logic lives in FlaxElectraForMultipleChoiceModule.
    module_class = FlaxElectraForMultipleChoiceModule


# adapt docstring slightly for FlaxElectraForMultipleChoice
overwrite_call_docstring(
    FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxElectraForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxElectraForQuestionAnsweringModule(nn.Module):
    """Electra backbone with a SQuAD-style span head producing start/end logits per token."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.qa_outputs(hidden_states)
        # Splits the last dim into `num_labels` chunks; the two-way unpack below requires
        # config.num_labels == 2 (start/end) — standard for span QA, but not enforced here.
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
    # The pretrained-model base class builds `module_class` from the config.
    module_class = FlaxElectraForQuestionAnsweringModule
+
+
# attach a runnable usage example to `__call__`
append_call_sample_docstring(
    FlaxElectraForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxElectraClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.dense = nn.Dense(cfg.hidden_size, dtype=self.dtype)
        # Fall back to the generic hidden dropout probability when no
        # classifier-specific dropout is configured.
        if cfg.classifier_dropout is not None:
            dropout_rate = cfg.classifier_dropout
        else:
            dropout_rate = cfg.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.out_proj = nn.Dense(cfg.num_labels, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True):
        # Pool by taking the first token (equivalent to [CLS]).
        pooled = hidden_states[:, 0, :]
        pooled = self.dropout(pooled, deterministic=deterministic)
        pooled = self.dense(pooled)
        # although BERT uses tanh here, it seems Electra authors used gelu
        pooled = ACT2FN["gelu"](pooled)
        pooled = self.dropout(pooled, deterministic=deterministic)
        return self.out_proj(pooled)
+
+
class FlaxElectraForSequenceClassificationModule(nn.Module):
    """Electra encoder with a sequence-level classification/regression head."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Backbone forward pass.
        encoder_outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The head pools internally (first token), so it receives the full
        # sequence of hidden states.
        logits = self.classifier(encoder_outputs[0], deterministic=deterministic)

        if not return_dict:
            return (logits,) + encoder_outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
    # The pretrained-model base class builds `module_class` from the config.
    module_class = FlaxElectraForSequenceClassificationModule
+
+
# attach a runnable usage example to `__call__`
append_call_sample_docstring(
    FlaxElectraForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
+
+
class FlaxElectraForCausalLMModule(nn.Module):
    """Electra backbone plus the generator LM head, used for causal
    (autoregressive) language modeling with optional cross-attention."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        # When embeddings are tied, the output projection has no kernel of its
        # own; it reuses the input embedding matrix at call time (see below).
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[jnp.ndarray] = None,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Tied head: project with the transposed input embedding matrix,
            # fetched directly from the backbone's parameter tree.
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
+
+
@add_start_docstrings(
    """
    Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
    autoregressive tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Build the static-shape kwargs (cache, mask, position ids) used on
        the first step of autoregressive generation."""
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Positions are derived from the mask so that padding does not
            # advance the position counter.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Advance the cache and position ids by one step between decodes."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
+
+
# attach a runnable usage example to `__call__`
append_call_sample_docstring(
    FlaxElectraForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
+
+
# Explicit public API of this module (consumed by `from ... import *` and the
# library's lazy import machinery).
__all__ = [
    "FlaxElectraForCausalLM",
    "FlaxElectraForMaskedLM",
    "FlaxElectraForMultipleChoice",
    "FlaxElectraForPreTraining",
    "FlaxElectraForQuestionAnswering",
    "FlaxElectraForSequenceClassification",
    "FlaxElectraForTokenClassification",
    "FlaxElectraModel",
    "FlaxElectraPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dc3ac8ebf8cb32056aac75586f9f5eb15f58e7e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py
@@ -0,0 +1,1776 @@
+# coding=utf-8
+# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF Electra model."""
+
+from __future__ import annotations
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...activations_tf import get_tf_activation
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFMaskedLMOutput,
+ TFMultipleChoiceModelOutput,
+ TFQuestionAnsweringModelOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFMultipleChoiceLoss,
+ TFPreTrainedModel,
+ TFQuestionAnsweringLoss,
+ TFSequenceClassificationLoss,
+ TFSequenceSummary,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ keras_serializable,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_electra import ElectraConfig
+
+
logger = logging.get_logger(__name__)

# Default checkpoint/config names substituted into the generated docstring examples.
_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra
class TFElectraSelfAttention(keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Acts as self-attention by default; becomes cross-attention when
    `encoder_hidden_states` is passed, and supports key/value caching via
    `past_key_value` for decoder-style generation.
    """

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        # The hidden size must split evenly across the attention heads.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Pre-computed scaling denominator for the attention scores.
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        """Split the last dim into heads and move the head axis forward."""
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Return `(attention_output[, attention_probs][, past_key_value])`."""
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Uni-directional self-attention with cache: append the new step's
            # keys/values to the cached ones along the sequence axis.
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in TFElectraModel call() function)
            attention_scores = tf.add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    def build(self, input_shape=None):
        # Build the Q/K/V projections with a known last dimension so the layer
        # can be constructed without a sample input.
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra
class TFElectraSelfOutput(keras.layers.Layer):
    """Projects attention output back to hidden size, then applies dropout, a
    residual connection and layer normalization."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        projected = self.dense(inputs=hidden_states)
        projected = self.dropout(inputs=projected, training=training)
        # Residual connection around the attention block, then LayerNorm.
        return self.LayerNorm(inputs=projected + input_tensor)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Both sublayers see tensors whose last dimension is hidden_size.
        for sublayer in (getattr(self, "dense", None), getattr(self, "LayerNorm", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build([None, None, self.config.hidden_size])
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra
class TFElectraAttention(keras.layers.Layer):
    """Full attention sub-block: multi-head attention followed by the
    residual/output projection."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFElectraSelfAttention(config, name="self")
        self.dense_output = TFElectraSelfOutput(config, name="output")

    def prune_heads(self, heads):
        # Head pruning is not implemented for the TF port.
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor,
        encoder_attention_mask: tf.Tensor,
        past_key_value: Tuple[tf.Tensor],
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attn_results = self.self_attention(
            hidden_states=input_tensor,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        projected = self.dense_output(hidden_states=attn_results[0], input_tensor=input_tensor, training=training)
        # add attentions (possibly with past_key_value) if we output them
        return (projected,) + attn_results[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for sublayer in (getattr(self, "self_attention", None), getattr(self, "dense_output", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra
class TFElectraIntermediate(keras.layers.Layer):
    """Position-wise feed-forward expansion (hidden_size -> intermediate_size)
    with the configured activation."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        # A string activation name is resolved through the TF activation
        # registry; a callable is used as-is.
        self.intermediate_act_fn = (
            get_tf_activation(config.hidden_act) if isinstance(config.hidden_act, str) else config.hidden_act
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        return self.intermediate_act_fn(self.dense(inputs=hidden_states))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra
class TFElectraOutput(keras.layers.Layer):
    """Feed-forward contraction (intermediate_size -> hidden_size) with
    dropout, residual connection and layer normalization."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        contracted = self.dense(inputs=hidden_states)
        contracted = self.dropout(inputs=contracted, training=training)
        # Residual connection around the feed-forward block, then LayerNorm.
        return self.LayerNorm(inputs=contracted + input_tensor)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                # This dense layer consumes the expanded intermediate tensor.
                self.dense.build([None, None, self.config.intermediate_size])
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra
class TFElectraLayer(keras.layers.Layer):
    """One full transformer layer: self-attention, optional cross-attention
    (decoder configs with `add_cross_attention`), and the feed-forward block."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFElectraAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense for a decoder config.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = TFElectraAttention(config, name="crossattention")
        self.intermediate = TFElectraIntermediate(config, name="intermediate")
        self.bert_output = TFElectraOutput(config, name="output")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_value: Tuple[tf.Tensor] | None,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Return `(layer_output, *attentions[, present_key_value])`."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            input_tensor=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                input_tensor=attention_output,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        intermediate_output = self.intermediate(hidden_states=attention_output)
        layer_output = self.bert_output(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs  # add attentions if we output them

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "bert_output", None) is not None:
            with tf.name_scope(self.bert_output.name):
                self.bert_output.build(None)
        if getattr(self, "crossattention", None) is not None:
            with tf.name_scope(self.crossattention.name):
                self.crossattention.build(None)
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra
class TFElectraEncoder(keras.layers.Layer):
    """Stack of `config.num_hidden_layers` TFElectraLayer blocks, with optional
    collection of hidden states, attentions and decoder KV caches."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        head_mask: tf.Tensor,
        encoder_hidden_states: tf.Tensor | None,
        encoder_attention_mask: tf.Tensor | None,
        past_key_values: Tuple[Tuple[tf.Tensor]] | None,
        use_cache: Optional[bool],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
        """Run every layer in sequence, threading hidden states through."""
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            # Hidden states are recorded *before* each layer, so the tuple
            # starts with the embedding output.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                # A decoder layer returns its present key/value as last element.
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention and encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
            )

        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra
class TFElectraPooler(keras.layers.Layer):
    """Pools the sequence by running the first token's hidden state through a
    tanh-activated dense layer."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        return self.dense(inputs=hidden_states[:, 0])

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
+
+
# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra
class TFElectraEmbeddings(keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config: ElectraConfig, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        # Electra uses a separate (usually smaller) embedding size from the
        # hidden size; a projection elsewhere bridges the two.
        self.embedding_size = config.embedding_size
        self.max_position_embeddings = config.max_position_embeddings
        self.initializer_range = config.initializer_range
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)

    def build(self, input_shape=None):
        # NOTE(review): the three weight tables are created before the
        # `self.built` early-return below — presumably `build` is only invoked
        # once per layer instance; confirm before refactoring this order.
        with tf.name_scope("word_embeddings"):
            self.weight = self.add_weight(
                name="weight",
                shape=[self.config.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("token_type_embeddings"):
            self.token_type_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.config.type_vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        with tf.name_scope("position_embeddings"):
            self.position_embeddings = self.add_weight(
                name="embeddings",
                shape=[self.max_position_embeddings, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        if self.built:
            return
        self.built = True
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.embedding_size])

    # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids: Optional[tf.Tensor] = None,
        position_ids: Optional[tf.Tensor] = None,
        token_type_ids: Optional[tf.Tensor] = None,
        inputs_embeds: Optional[tf.Tensor] = None,
        past_key_values_length=0,
        training: bool = False,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Need to provide either `input_ids` or `input_embeds`.")

        if input_ids is not None:
            # Guard against out-of-vocabulary ids before the gather.
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if token_type_ids is None:
            token_type_ids = tf.fill(dims=input_shape, value=0)

        if position_ids is None:
            # Positions start after any cached (past) tokens during generation.
            position_ids = tf.expand_dims(
                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
            )

        position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
        token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids)
        final_embeddings = inputs_embeds + position_embeds + token_type_embeds
        final_embeddings = self.LayerNorm(inputs=final_embeddings)
        final_embeddings = self.dropout(inputs=final_embeddings, training=training)

        return final_embeddings
+
+
+class TFElectraDiscriminatorPredictions(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(config.hidden_size, name="dense")
+ self.dense_prediction = keras.layers.Dense(1, name="dense_prediction")
+ self.config = config
+
+ def call(self, discriminator_hidden_states, training=False):
+ hidden_states = self.dense(discriminator_hidden_states)
+ hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states)
+ logits = tf.squeeze(self.dense_prediction(hidden_states), -1)
+
+ return logits
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "dense_prediction", None) is not None:
+ with tf.name_scope(self.dense_prediction.name):
+ self.dense_prediction.build([None, None, self.config.hidden_size])
+
+
+class TFElectraGeneratorPredictions(keras.layers.Layer):
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
+ self.dense = keras.layers.Dense(config.embedding_size, name="dense")
+ self.config = config
+
+ def call(self, generator_hidden_states, training=False):
+ hidden_states = self.dense(generator_hidden_states)
+ hidden_states = get_tf_activation("gelu")(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+
+ return hidden_states
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "LayerNorm", None) is not None:
+ with tf.name_scope(self.LayerNorm.name):
+ self.LayerNorm.build([None, None, self.config.embedding_size])
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+
+
+class TFElectraPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ElectraConfig
+    base_model_prefix = "electra"
+    # When the model is loaded from a PT model
+    # (the LM head re-uses the input embedding weights — see TFElectraMaskedLMHead — so an extra
+    # `generator_lm_head.weight` entry in a checkpoint can be ignored safely)
+    _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"]
+    # Dropout layers carry no weights, so missing `dropout` entries are expected.
+    _keys_to_ignore_on_load_missing = [r"dropout"]
+
+
+@keras_serializable
+class TFElectraMainLayer(keras.layers.Layer):
+    """Bare ELECTRA transformer body: embeddings (+ optional projection to hidden size) and the encoder stack."""
+
+    config_class = ElectraConfig
+
+    def __init__(self, config, **kwargs):
+        super().__init__(**kwargs)
+
+        self.config = config
+        self.is_decoder = config.is_decoder
+
+        self.embeddings = TFElectraEmbeddings(config, name="embeddings")
+
+        # ELECTRA may use an embedding size smaller than the hidden size; bridge them with a projection.
+        if config.embedding_size != config.hidden_size:
+            self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project")
+
+        self.encoder = TFElectraEncoder(config, name="encoder")
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.weight = value
+        self.embeddings.vocab_size = shape_list(value)[0]
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        raise NotImplementedError
+
+    def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0):
+        """Turn a 2D padding mask into a 4D additive mask (0.0 = attend, -10000.0 = masked); causal for decoders."""
+        batch_size, seq_length = input_shape
+
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        attention_mask_shape = shape_list(attention_mask)
+
+        mask_seq_length = seq_length + past_key_values_length
+        # Copied from `modeling_tf_t5.py`
+        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
+        # - if the model is a decoder, apply a causal mask in addition to the padding mask
+        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+        if self.is_decoder:
+            seq_ids = tf.range(mask_seq_length)
+            causal_mask = tf.less_equal(
+                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
+                seq_ids[None, :, None],
+            )
+            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
+            extended_attention_mask = causal_mask * attention_mask[:, None, :]
+            attention_mask_shape = shape_list(extended_attention_mask)
+            extended_attention_mask = tf.reshape(
+                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
+            )
+            if past_key_values_length > 0:
+                # Only the current (non-cached) query positions need mask rows.
+                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
+        else:
+            extended_attention_mask = tf.reshape(
+                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype)
+        one_cst = tf.constant(1.0, dtype=dtype)
+        ten_thousand_cst = tf.constant(-10000.0, dtype=dtype)
+        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)
+
+        return extended_attention_mask
+
+    def get_head_mask(self, head_mask):
+        """Head masking is not implemented for TF ELECTRA; returns a per-layer list of `None`."""
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        return head_mask
+
+    @unpack_inputs
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
+        """Embed the inputs, build masks, then run the encoder; returns the encoder's output directly."""
+        if not self.config.is_decoder:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+
+        if past_key_values is None:
+            past_key_values_length = 0
+            past_key_values = [None] * len(self.encoder.layer)
+        else:
+            # Cached key tensors are (batch, heads, past_seq, head_dim); take the cached length.
+            past_key_values_length = shape_list(past_key_values[0][0])[-2]
+
+        if attention_mask is None:
+            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)
+
+        if token_type_ids is None:
+            token_type_ids = tf.fill(dims=input_shape, value=0)
+
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+            training=training,
+        )
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, input_shape, hidden_states.dtype, past_key_values_length
+        )
+
+        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
+        if self.is_decoder and encoder_attention_mask is not None:
+            # If a 2D ou 3D attention mask is provided for the cross-attention
+            # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
+            # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
+            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
+            if num_dims_encoder_attention_mask == 3:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+            if num_dims_encoder_attention_mask == 2:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
+            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
+            # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))
+
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
+        else:
+            encoder_extended_attention_mask = None
+
+        head_mask = self.get_head_mask(head_mask)
+
+        # Present only when embedding_size != hidden_size (see __init__).
+        if hasattr(self, "embeddings_project"):
+            hidden_states = self.embeddings_project(hidden_states, training=training)
+
+        hidden_states = self.encoder(
+            hidden_states=hidden_states,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return hidden_states
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "embeddings_project", None) is not None:
+            with tf.name_scope(self.embeddings_project.name):
+                self.embeddings_project.build([None, None, self.config.embedding_size])
+
+
+@dataclass
+class TFElectraForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`TFElectraForPreTraining`].
+
+    Args:
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Prediction scores of the head (scores for each token before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    # NOTE(review): the previous docstring also documented a `loss` entry, but this class declares no
+    # `loss` field and `TFElectraForPreTraining.call` never computes one, so that entry was removed.
+    logits: Optional[tf.Tensor] = None
+    hidden_states: Tuple[tf.Tensor] | None = None
+    attentions: Tuple[tf.Tensor] | None = None
+
+
+# Class-level docstring prefix shared by every ELECTRA head below (injected via `add_start_docstrings`).
+ELECTRA_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
+    behavior.
+
+
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+    Parameters:
+        config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+# Per-`call` input documentation; `{0}` is substituted with the expected input shape at decoration time.
+ELECTRA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
+            [`PreTrainedTokenizer.encode`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
+    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
+    "hidden size and embedding size are different. "
+    ""
+    "Both the generator and discriminator checkpoints may be loaded into this model.",
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraModel(TFElectraPreTrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.electra = TFElectraMainLayer(config, name="electra")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFBaseModelOutputWithPastAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
+        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
+        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
+        r"""
+        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`):
+            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`). Set to `False` during training, `True` during generation
+        """
+        # Pure pass-through to the main layer; no head is applied.
+        outputs = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        return outputs
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+
+
+@add_start_docstrings(
+    """
+    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
+
+    Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model
+    of the two to have the correct classification head to be used for this model.
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForPreTraining(TFElectraPreTrainedModel):
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
+
+        self.electra = TFElectraMainLayer(config, name="electra")
+        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFElectraForPreTrainingOutput, Tuple[tf.Tensor]]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoTokenizer, TFElectraForPreTraining
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
+        >>> model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
+        >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
+        >>> outputs = model(input_ids)
+        >>> scores = outputs[0]
+        ```"""
+        discriminator_hidden_states = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        # Score every token from the final hidden states; no `labels` argument, so no loss is computed here.
+        discriminator_sequence_output = discriminator_hidden_states[0]
+        logits = self.discriminator_predictions(discriminator_sequence_output)
+
+        if not return_dict:
+            return (logits,) + discriminator_hidden_states[1:]
+
+        return TFElectraForPreTrainingOutput(
+            logits=logits,
+            hidden_states=discriminator_hidden_states.hidden_states,
+            attentions=discriminator_hidden_states.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "discriminator_predictions", None) is not None:
+            with tf.name_scope(self.discriminator_predictions.name):
+                self.discriminator_predictions.build(None)
+
+
+class TFElectraMaskedLMHead(keras.layers.Layer):
+ def __init__(self, config, input_embeddings, **kwargs):
+ super().__init__(**kwargs)
+
+ self.config = config
+ self.embedding_size = config.embedding_size
+ self.input_embeddings = input_embeddings
+
+ def build(self, input_shape):
+ self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias")
+
+ super().build(input_shape)
+
+ def get_output_embeddings(self):
+ return self.input_embeddings
+
+ def set_output_embeddings(self, value):
+ self.input_embeddings.weight = value
+ self.input_embeddings.vocab_size = shape_list(value)[0]
+
+ def get_bias(self):
+ return {"bias": self.bias}
+
+ def set_bias(self, value):
+ self.bias = value["bias"]
+ self.config.vocab_size = shape_list(value["bias"])[0]
+
+ def call(self, hidden_states):
+ seq_length = shape_list(tensor=hidden_states)[1]
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
+ hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
+ hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size])
+ hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
+
+ return hidden_states
+
+
+@add_start_docstrings(
+    """
+    Electra model with a language modeling head on top.
+
+    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
+    the two to have been trained for the masked language modeling task.
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
+
+        self.config = config
+        self.electra = TFElectraMainLayer(config, name="electra")
+        self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")
+
+        if isinstance(config.hidden_act, str):
+            self.activation = get_tf_activation(config.hidden_act)
+        else:
+            self.activation = config.hidden_act
+
+        # The LM head shares weights with the input embeddings (weight tying).
+        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")
+
+    def get_lm_head(self):
+        return self.generator_lm_head
+
+    def get_prefix_bias_name(self):
+        # Deprecated accessor kept for backward compatibility only.
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+        return self.name + "/" + self.generator_lm_head.name
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="google/electra-small-generator",
+        output_type=TFMaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        mask="[MASK]",
+        expected_output="'paris'",
+        expected_loss=1.22,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        """
+        generator_hidden_states = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        generator_sequence_output = generator_hidden_states[0]
+        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
+        prediction_scores = self.generator_lm_head(prediction_scores, training=training)
+        # Loss is computed only when labels are supplied (standard -100 ignore-index convention).
+        loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores)
+
+        if not return_dict:
+            output = (prediction_scores,) + generator_hidden_states[1:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFMaskedLMOutput(
+            loss=loss,
+            logits=prediction_scores,
+            hidden_states=generator_hidden_states.hidden_states,
+            attentions=generator_hidden_states.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "generator_predictions", None) is not None:
+            with tf.name_scope(self.generator_predictions.name):
+                self.generator_predictions.build(None)
+        if getattr(self, "generator_lm_head", None) is not None:
+            with tf.name_scope(self.generator_lm_head.name):
+                self.generator_lm_head.build(None)
+
+
+class TFElectraClassificationHead(keras.layers.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config, **kwargs):
+ super().__init__(**kwargs)
+
+ self.dense = keras.layers.Dense(
+ config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
+ )
+ classifier_dropout = (
+ config.classifhidden_dropout_probier_dropout
+ if config.classifier_dropout is not None
+ else config.hidden_dropout_prob
+ )
+ self.dropout = keras.layers.Dropout(classifier_dropout)
+ self.out_proj = keras.layers.Dense(
+ config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj"
+ )
+ self.config = config
+
+ def call(self, inputs, **kwargs):
+ x = inputs[:, 0, :] # take token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = get_tf_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here
+ x = self.dropout(x)
+ x = self.out_proj(x)
+
+ return x
+
+ def build(self, input_shape=None):
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "dense", None) is not None:
+ with tf.name_scope(self.dense.name):
+ self.dense.build([None, None, self.config.hidden_size])
+ if getattr(self, "out_proj", None) is not None:
+ with tf.name_scope(self.out_proj.name):
+ self.out_proj.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+    pooled output) e.g. for GLUE tasks.
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.num_labels = config.num_labels
+        self.electra = TFElectraMainLayer(config, name="electra")
+        self.classifier = TFElectraClassificationHead(config, name="classifier")
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="bhadresh-savani/electra-base-emotion",
+        output_type=TFSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="'joy'",
+        expected_loss=0.06,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        outputs = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        # The head pools internally from the first token; pass the full sequence output.
+        logits = self.classifier(outputs[0])
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build(None)
+
+
+@add_start_docstrings(
+    """
+    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.electra = TFElectraMainLayer(config, name="electra")
+        # Summarizes the encoder output into one vector per (example, choice).
+        self.sequence_summary = TFSequenceSummary(
+            config, initializer_range=config.initializer_range, name="sequence_summary"
+        )
+        # One score per choice; scores are reshaped to (batch, num_choices) later.
+        self.classifier = keras.layers.Dense(
+            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=TFMultipleChoiceModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
+            where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above)
+        """
+
+        # Inputs arrive as (batch, num_choices, seq_len); read the dims from
+        # whichever of input_ids / inputs_embeds was provided.
+        if input_ids is not None:
+            num_choices = shape_list(input_ids)[1]
+            seq_length = shape_list(input_ids)[2]
+        else:
+            num_choices = shape_list(inputs_embeds)[1]
+            seq_length = shape_list(inputs_embeds)[2]
+
+        # Flatten the choice dimension so the encoder sees (batch * num_choices, seq_len).
+        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
+        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
+        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
+        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
+        flat_inputs_embeds = (
+            tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
+            if inputs_embeds is not None
+            else None
+        )
+        outputs = self.electra(
+            input_ids=flat_input_ids,
+            attention_mask=flat_attention_mask,
+            token_type_ids=flat_token_type_ids,
+            position_ids=flat_position_ids,
+            head_mask=head_mask,
+            inputs_embeds=flat_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        logits = self.sequence_summary(outputs[0])
+        logits = self.classifier(logits)
+        # Un-flatten: one logit per choice, per example.
+        reshaped_logits = tf.reshape(logits, (-1, num_choices))
+        loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits)
+
+        if not return_dict:
+            output = (reshaped_logits,) + outputs[1:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFMultipleChoiceModelOutput(
+            loss=loss,
+            logits=reshaped_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        """Build encoder, summary layer, and classifier under per-layer name scopes."""
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "sequence_summary", None) is not None:
+            with tf.name_scope(self.sequence_summary.name):
+                self.sequence_summary.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    Electra model with a token classification head on top.
+
+    Both the discriminator and generator may be loaded into this model.
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
+    def __init__(self, config, **kwargs):
+        super().__init__(config, **kwargs)
+
+        self.electra = TFElectraMainLayer(config, name="electra")
+        # Prefer the dedicated classifier dropout if configured, otherwise fall
+        # back to the encoder's hidden dropout probability.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = keras.layers.Dropout(classifier_dropout)
+        # Per-token projection to num_labels.
+        self.classifier = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english",
+        output_type=TFTokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']",
+        expected_loss=0.11,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
+        r"""
+        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+        """
+        discriminator_hidden_states = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        discriminator_sequence_output = discriminator_hidden_states[0]
+        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
+        # One logit vector per token position.
+        logits = self.classifier(discriminator_sequence_output)
+        loss = None if labels is None else self.hf_compute_loss(labels, logits)
+
+        if not return_dict:
+            output = (logits,) + discriminator_hidden_states[1:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFTokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=discriminator_hidden_states.hidden_states,
+            attentions=discriminator_hidden_states.attentions,
+        )
+
+    def build(self, input_shape=None):
+        """Build encoder and classifier under per-layer name scopes.
+
+        The dropout layer has no weights, so it needs no explicit build.
+        """
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "classifier", None) is not None:
+            with tf.name_scope(self.classifier.name):
+                self.classifier.build([None, None, self.config.hidden_size])
+
+
+@add_start_docstrings(
+    """
+    Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    ELECTRA_START_DOCSTRING,
+)
+class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.num_labels = config.num_labels
+        self.electra = TFElectraMainLayer(config, name="electra")
+        # Projects each token's hidden state to (start, end) span logits.
+        self.qa_outputs = keras.layers.Dense(
+            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
+        )
+        self.config = config
+
+    @unpack_inputs
+    @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint="bhadresh-savani/electra-base-squad2",
+        output_type=TFQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        qa_target_start_index=11,
+        qa_target_end_index=12,
+        expected_output="'a nice puppet'",
+        expected_loss=2.64,
+    )
+    def call(
+        self,
+        input_ids: TFModelInputType | None = None,
+        attention_mask: np.ndarray | tf.Tensor | None = None,
+        token_type_ids: np.ndarray | tf.Tensor | None = None,
+        position_ids: np.ndarray | tf.Tensor | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        inputs_embeds: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        start_positions: np.ndarray | tf.Tensor | None = None,
+        end_positions: np.ndarray | tf.Tensor | None = None,
+        training: Optional[bool] = False,
+    ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
+        r"""
+        start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        discriminator_hidden_states = self.electra(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        discriminator_sequence_output = discriminator_hidden_states[0]
+        logits = self.qa_outputs(discriminator_sequence_output)
+        # Split the 2-channel projection into separate start/end logit tensors.
+        start_logits, end_logits = tf.split(logits, 2, axis=-1)
+        start_logits = tf.squeeze(start_logits, axis=-1)
+        end_logits = tf.squeeze(end_logits, axis=-1)
+        loss = None
+
+        if start_positions is not None and end_positions is not None:
+            # hf_compute_loss for QA expects a dict keyed by position kind.
+            labels = {"start_position": start_positions}
+            labels["end_position"] = end_positions
+            loss = self.hf_compute_loss(labels, (start_logits, end_logits))
+
+        if not return_dict:
+            output = (
+                start_logits,
+                end_logits,
+            ) + discriminator_hidden_states[1:]
+
+            return ((loss,) + output) if loss is not None else output
+
+        return TFQuestionAnsweringModelOutput(
+            loss=loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=discriminator_hidden_states.hidden_states,
+            attentions=discriminator_hidden_states.attentions,
+        )
+
+    def build(self, input_shape=None):
+        """Build encoder and QA projection under per-layer name scopes."""
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "electra", None) is not None:
+            with tf.name_scope(self.electra.name):
+                self.electra.build(None)
+        if getattr(self, "qa_outputs", None) is not None:
+            with tf.name_scope(self.qa_outputs.name):
+                self.qa_outputs.build([None, None, self.config.hidden_size])
+
+
+# Public API of this module (names exported on `from ... import *`).
+__all__ = [
+    "TFElectraForMaskedLM",
+    "TFElectraForMultipleChoice",
+    "TFElectraForPreTraining",
+    "TFElectraForQuestionAnswering",
+    "TFElectraForSequenceClassification",
+    "TFElectraForTokenClassification",
+    "TFElectraModel",
+    "TFElectraPreTrainedModel",
+]
diff --git a/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b21527e6cdae25e1bdc40a81a01f3b2f014ffb5
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py
@@ -0,0 +1,511 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+import unicodedata
+from typing import List, Optional, Tuple
+
+from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+# Copied from transformers.models.bert.tokenization_bert.load_vocab
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary.
+
+    Each line of the file is one token; its (0-based) line number becomes the
+    token id. Insertion order is preserved via OrderedDict.
+    """
+    vocab = collections.OrderedDict()
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        tokens = reader.readlines()
+    for index, token in enumerate(tokens):
+        # Strip only the trailing newline; other whitespace may be meaningful.
+        token = token.rstrip("\n")
+        vocab[token] = index
+    return vocab
+
+
+# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text.
+
+    Returns an empty list for empty/whitespace-only input.
+    """
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+# Copied from transformers.models.bert.tokenization_bert.BertTokenizer with Bert->Electra,BERT->Electra
+class ElectraTokenizer(PreTrainedTokenizer):
+    r"""
+    Construct a Electra tokenizer. Based on WordPiece.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+            Whether or not to do basic tokenization before WordPiece.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`
+        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+            The classifier token which is used when doing sequence classification (classification of the whole sequence
+            instead of per-token classification). It is the first token of the sequence when built with special tokens.
+        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+            The token used for masking values. This is the token used when training this model with masked language
+            modeling. This is the token which the model will try to predict.
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original Electra).
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
+            extra spaces.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+
+    def __init__(
+        self,
+        vocab_file,
+        do_lower_case=True,
+        do_basic_tokenize=True,
+        never_split=None,
+        unk_token="[UNK]",
+        sep_token="[SEP]",
+        pad_token="[PAD]",
+        cls_token="[CLS]",
+        mask_token="[MASK]",
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        clean_up_tokenization_spaces=True,
+        **kwargs,
+    ):
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
+                " model use `tokenizer = ElectraTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        self.vocab = load_vocab(vocab_file)
+        # Reverse mapping (id -> token) for decoding.
+        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(
+                do_lower_case=do_lower_case,
+                never_split=never_split,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                strip_accents=strip_accents,
+            )
+
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
+
+        # The vocab must exist before calling the base __init__, which may
+        # tokenize during special-token setup.
+        super().__init__(
+            do_lower_case=do_lower_case,
+            do_basic_tokenize=do_basic_tokenize,
+            never_split=never_split,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            pad_token=pad_token,
+            cls_token=cls_token,
+            mask_token=mask_token,
+            tokenize_chinese_chars=tokenize_chinese_chars,
+            strip_accents=strip_accents,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    @property
+    def do_lower_case(self):
+        # Delegates to the basic tokenizer's setting.
+        return self.basic_tokenizer.do_lower_case
+
+    @property
+    def vocab_size(self):
+        return len(self.vocab)
+
+    def get_vocab(self):
+        """Return the full vocabulary (base vocab plus added tokens) as a dict."""
+        return dict(self.vocab, **self.added_tokens_encoder)
+
+    def _tokenize(self, text, split_special_tokens=False):
+        """Tokenize text: basic (whitespace/punctuation) split, then WordPiece."""
+        split_tokens = []
+        if self.do_basic_tokenize:
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
+                # If the token is part of the never_split set
+                if token in self.basic_tokenizer.never_split:
+                    split_tokens.append(token)
+                else:
+                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.ids_to_tokens.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (string) in a single string."""
+        # Joining then removing " ##" re-attaches WordPiece continuations.
+        out_string = " ".join(tokens).replace(" ##", "").strip()
+        return out_string
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens. A Electra sequence has the following format:
+
+        - single sequence: `[CLS] X [SEP]`
+        - pair of sequences: `[CLS] A [SEP] B [SEP]`
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+        cls = [self.cls_token_id]
+        sep = [self.sep_token_id]
+        return cls + token_ids_0 + sep + token_ids_1 + sep
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        # Mask shape mirrors build_inputs_with_special_tokens: [CLS] ... [SEP] (... [SEP]).
+        if token_ids_1 is not None:
+            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
+        return [1] + ([0] * len(token_ids_0)) + [1]
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A Electra sequence
+        pair mask has the following format:
+
+        ```
+        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
+        | first sequence    | second sequence |
+        ```
+
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
+        """
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        """Write the vocabulary, one token per line in id order, and return the file path."""
+        index = 0
+        if os.path.isdir(save_directory):
+            vocab_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+            )
+        else:
+            # save_directory is treated as a full file path in this case.
+            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                # Warn (but continue) if ids have gaps; line number == id must hold
+                # for load_vocab to round-trip.
+                if index != token_index:
+                    logger.warning(
+                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                        " Please check that the vocabulary is not corrupted!"
+                    )
+                    index = token_index
+                writer.write(token + "\n")
+                index += 1
+        return (vocab_file,)
+
+
+# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer
+class BasicTokenizer:
+    """
+    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+
+    Args:
+        do_lower_case (`bool`, *optional*, defaults to `True`):
+            Whether or not to lowercase the input when tokenizing.
+        never_split (`Iterable`, *optional*):
+            Collection of tokens which will never be split during tokenization. Only has an effect when
+            `do_basic_tokenize=True`
+        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+            Whether or not to tokenize Chinese characters.
+
+            This should likely be deactivated for Japanese (see this
+            [issue](https://github.com/huggingface/transformers/issues/328)).
+        strip_accents (`bool`, *optional*):
+            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+            value for `lowercase` (as in the original BERT).
+        do_split_on_punc (`bool`, *optional*, defaults to `True`):
+            In some instances we want to skip the basic punctuation splitting so that later tokenization can capture
+            the full context of the words, such as contractions.
+    """
+
+    def __init__(
+        self,
+        do_lower_case=True,
+        never_split=None,
+        tokenize_chinese_chars=True,
+        strip_accents=None,
+        do_split_on_punc=True,
+    ):
+        if never_split is None:
+            never_split = []
+        self.do_lower_case = do_lower_case
+        # Stored as a set for O(1) membership checks during tokenization.
+        self.never_split = set(never_split)
+        self.tokenize_chinese_chars = tokenize_chinese_chars
+        self.strip_accents = strip_accents
+        self.do_split_on_punc = do_split_on_punc
+
+    def tokenize(self, text, never_split=None):
+        """
+        Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer.
+
+        Args:
+            never_split (`List[str]`, *optional*)
+                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                [`PreTrainedTokenizer.tokenize`]) List of token not to split.
+        """
+        # union() returns a new set by concatenating the two sets.
+        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
+        text = self._clean_text(text)
+
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
+        if self.tokenize_chinese_chars:
+            text = self._tokenize_chinese_chars(text)
+        # prevents treating the same character with different unicode codepoints as different characters
+        unicode_normalized_text = unicodedata.normalize("NFC", text)
+        orig_tokens = whitespace_tokenize(unicode_normalized_text)
+        split_tokens = []
+        for token in orig_tokens:
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    # strip_accents=None (unset) defaults to stripping when lowercasing.
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        # NFD decomposition separates base characters from combining marks ("Mn").
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text, never_split=None):
+        """Splits punctuation on a piece of text."""
+        if not self.do_split_on_punc or (never_split is not None and text in never_split):
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                # Each punctuation character becomes its own token.
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like the all of the other languages.
+        if (
+            (cp >= 0x4E00 and cp <= 0x9FFF)
+            or (cp >= 0x3400 and cp <= 0x4DBF)  #
+            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+            or (cp >= 0xF900 and cp <= 0xFAFF)
+            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+        ):  #
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            # Drop NUL, the Unicode replacement character, and control chars.
+            if cp == 0 or cp == 0xFFFD or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer
+class WordpieceTokenizer:
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+        # vocab: mapping of token string -> id; only membership is used here.
+        self.vocab = vocab
+        self.unk_token = unk_token
+        # Words longer than this are mapped straight to unk_token.
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """
+        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+        tokenization using the given vocabulary.
+
+        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
+
+        Args:
+            text: A single token or whitespace separated tokens. This should have
+                already been passed through *BasicTokenizer*.
+
+        Returns:
+            A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                # Greedy longest-match: shrink `end` until chars[start:end] is in vocab.
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        # Non-initial pieces carry the "##" continuation prefix.
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    # No piece matched: the whole word becomes unk_token.
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+# Public API of this module (names exported on `from ... import *`).
+__all__ = ["ElectraTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..34ea4339b9382b28c6f9a5842f88af8807f9f928
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import List, Optional, Tuple
+
+from tokenizers import normalizers
+
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from .tokenization_electra import ElectraTokenizer
+
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
+
+
# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with Bert->Electra , BERT->ELECTRA
class ElectraTokenizerFast(PreTrainedTokenizerFast):
    r"""
    Construct a "fast" ELECTRA tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece.

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            File containing the vocabulary.
        tokenizer_file (`str`, *optional*):
            Path to a serialized backend `tokenizer.json` file.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, used when building a sequence from multiple sequences (e.g. two sequences for
            sequence classification, or a text and a question for question answering). Also the last token of a
            sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token, used for sequence-level classification. It is the first token of a sequence built
            with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values in masked-language-modeling training.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this
            issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original ELECTRA).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    slow_tokenizer_class = ElectraTokenizer

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        do_lower_case=True,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            tokenizer_file=tokenizer_file,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        # The serialized tokenizer.json may have been created with different
        # normalization options; rebuild the backend normalizer if any of the
        # requested options disagree with the loaded state.
        state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
        wanted = {
            "lowercase": do_lower_case,
            "strip_accents": strip_accents,
            "handle_chinese_chars": tokenize_chinese_chars,
        }
        if any(state.get(key, value) != value for key, value in wanted.items()):
            normalizer_class = getattr(normalizers, state.pop("type"))
            state.update(wanted)
            self.backend_tokenizer.normalizer = normalizer_class(**state)

        self.do_lower_case = do_lower_case

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating
        and adding special tokens. A ELECTRA sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + [self.sep_token_id]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ELECTRA
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        # First segment is [CLS] token_ids_0 [SEP] -> len(token_ids_0) + 2 zeros.
        first_segment_len = len(token_ids_0) + 2
        if token_ids_1 is None:
            return [0] * first_segment_len
        # Second segment is token_ids_1 [SEP] -> len(token_ids_1) + 1 ones.
        return [0] * first_segment_len + [1] * (len(token_ids_1) + 1)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save the WordPiece vocabulary into `save_directory` via the backend model and return the file paths."""
        saved_files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(saved_files)
+
+
+__all__ = ["ElectraTokenizerFast"]
diff --git a/docs/transformers/build/lib/transformers/models/emu3/__init__.py b/docs/transformers/build/lib/transformers/models/emu3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8555f58d1866451c38abb5559ef5bef9545f0b0
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers and IDEs: expose the real submodules eagerly so
    # names resolve without the runtime lazy-loading indirection.
    from .configuration_emu3 import *
    from .image_processing_emu3 import *
    from .modeling_emu3 import *
    from .processing_emu3 import *
else:
    import sys

    _file = globals()["__file__"]
    # At runtime, replace this package module with a lazy proxy so heavy
    # submodules (e.g. modeling code) are only imported on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b5abedf4016d5959c8eeea9a3d955470c8b1f13
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py
@@ -0,0 +1,327 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Union
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_rope_utils import rope_config_validation
+
+
class Emu3VQVAEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Emu3VQVAE`]. It is used to instantiate an VQ-VAE
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration to the VQ model presented in Emu3 paper.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        codebook_size (`int`, *optional*, defaults to 32768):
            Codebook size of the VQ model.
        embed_dim (`int`, *optional*, defaults to 4):
            Dimension of the quantized vector in codebook.
        latent_channels (`int`, *optional*, defaults to 4):
            Dimension of the output channel of encoder and the input channel of decoder
        double_latent (`bool`, *optional*, defaults to `False`):
            Whether double the output dim of the encoder.
        in_channels (`int`, *optional*, defaults to 3):
            Input channel of encoder.
        out_channels (`int`, *optional*, defaults to 3):
            Output channel of decoder.
        temporal_downsample_factor (`int`, *optional*, defaults to 4):
            Temporal downsample factor.
        base_channels (`int`, *optional*, defaults to 256):
            Basic channel number of the intermediate blocks.
        channel_multiplier (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
            Channel scaling factor of the intermediate blocks.
        num_res_blocks (`int`, *optional*, defaults to 2):
            Residual block number in each stage.
        attn_resolutions (`List[int]`, *optional*, defaults to `[3]`):
            Stage indices to apply attention.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations in the attention layer.
        num_attention_heads (`int`, *optional*, defaults to 1):
            Number of attention heads for each attention layer.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig

    >>> # Initializing a video VQ model of Emu3 configuration
    >>> configuration = Emu3VQVAEConfig()

    >>> # Initializing a model from the Emu3 VQ model style configuration
    >>> model = Emu3VQVAE(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "emu3_vqgan"
    base_config_key = "vq_config"

    def __init__(
        self,
        codebook_size: int = 32768,
        embed_dim: int = 4,
        latent_channels: int = 4,
        double_latent: bool = False,
        in_channels: int = 3,
        out_channels: int = 3,
        temporal_downsample_factor: int = 4,
        base_channels: int = 256,
        channel_multiplier: Optional[List[int]] = None,
        num_res_blocks: int = 2,
        attn_resolutions: Optional[List[int]] = None,
        hidden_size: int = 1024,
        num_attention_heads: int = 1,
        attention_dropout: float = 0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.codebook_size = codebook_size
        self.embed_dim = embed_dim
        self.latent_channels = latent_channels
        self.double_latent = double_latent
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.temporal_downsample_factor = temporal_downsample_factor
        self.base_channels = base_channels
        # Resolve list defaults here rather than in the signature: a mutable
        # default argument is evaluated once and shared across every call, so
        # mutating one config's list would silently affect all other instances.
        self.channel_multiplier = channel_multiplier if channel_multiplier is not None else [1, 2, 2, 4]
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions if attn_resolutions is not None else [3]
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
+
+
class Emu3TextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Emu3TextModel`]. It is used to instantiate a
    emu3 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 184622):
            Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Emu3Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 9216):
            The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 151643):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 151849):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 151850):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.


    ```python
    >>> from transformers import Emu3Model, Emu3Config

    >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration
    >>> configuration = Emu3Config()

    >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration
    >>> model = Emu3Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "emu3_text_model"
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int = 184622,
        hidden_size: int = 4096,
        intermediate_size: int = 14336,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        num_key_value_heads: Optional[int] = 8,
        hidden_act: str = "silu",
        max_position_embeddings: int = 9216,
        rms_norm_eps: float = 1e-5,
        use_cache: bool = True,
        pad_token_id: int = 151643,
        bos_token_id: int = 151849,
        eos_token_id: int = 151850,
        tie_word_embeddings: bool = False,
        rope_theta: float = 1000000.0,
        rope_scaling: Optional[Dict] = None,
        mlp_bias: bool = False,
        attention_bias: bool = False,
        attention_dropout: float = 0.1,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.mlp_bias = mlp_bias
        self.attention_bias = attention_bias
        self.initializer_range = initializer_range
        # Validates `self.rope_scaling` against the selected rope type; raises on
        # malformed scaling dictionaries. Must run after the rope attributes are set.
        rope_config_validation(self)

        self.attention_dropout = attention_dropout

        # Token ids and embedding tying are handled by the base class; called last
        # so remaining kwargs are applied on top of the attributes set above.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
+
+
class Emu3Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate a
    emu3 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the
    [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vq_config (`Union[Dict, Emu3VQVAEConfig]`, *optional*):
            Emu3VQVAEConfig instance containing the configuration for the VQ-VAE model.
        text_config (`Union[Dict, Emu3TextConfig]``, *optional*):
            Emu3TextConfig instance containing the configuration for the language model.
        vocabulary_map (`dict`, *optional*):
            A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs.
    """

    model_type = "emu3"
    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig}

    def __init__(
        self,
        vq_config: Optional[Union[Dict, Emu3VQVAEConfig]] = None,
        text_config: Optional[Union[Dict, Emu3TextConfig]] = None,
        vocabulary_map: Optional[Dict[int, int]] = None,
        **kwargs,
    ):
        # Accept either ready-made sub-configs or plain dicts (e.g. loaded from a
        # saved config.json); fall back to default sub-configs when none is given.
        if vq_config is None:
            vq_config = Emu3VQVAEConfig()
        elif isinstance(vq_config, dict):
            vq_config = Emu3VQVAEConfig(**vq_config)

        if text_config is None:
            text_config = Emu3TextConfig()
        elif isinstance(text_config, dict):
            text_config = Emu3TextConfig(**text_config)

        self.vq_config = vq_config
        self.text_config = text_config
        self.vocabulary_map = vocabulary_map

        super().__init__(**kwargs)
+
+
+__all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py b/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac8db7e429031ee5157532bb8f4fb044844d281
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py
@@ -0,0 +1,448 @@
+# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+import re
+from typing import Dict, Optional
+
+import requests
+import torch
+from accelerate import init_empty_weights
+from PIL import Image
+
+from transformers import (
+ AutoModel,
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ Emu3Config,
+ Emu3ForConditionalGeneration,
+ Emu3ImageProcessor,
+ Emu3Processor,
+ Emu3TextConfig,
+ GenerationConfig,
+)
+from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/emu3/convert_emu3_weights_to_hf.py \
+ --vq_model_id BAAI/Emu3-VisionTokenizer --llm_model_id BAAI/Emu3-Chat --output_dir /output/path
+```
+
+Thereafter, models can be loaded via:
+
+```py
+from transformers import Emu3ForConditionalGeneration, Emu3Processor
+
+model = Emu3ForConditionalGeneration.from_pretrained("/output/path")
+processor = Emu3Processor.from_pretrained("/output/path")
+```
+
+"""
+
+
# GPT-2 byte-to-unicode table: maps every byte value to a printable character so
# raw token bytes can be written into JSON vocab files.
byte_encoder = bytes_to_unicode()
# Jinja chat template saved with the processor: renders image placeholders first,
# then text, and wraps assistant text in {% generation %} for assistant-token
# masking. NOTE(review): the image placeholder renders as '' (empty string) here,
# which looks like a lost glyph — verify against the upstream conversion script.
CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
+
+
# Tiktoken to HF conversion, thanks for Xenova
def token_bytes_to_string(b):
    """Render raw token bytes as a printable string via the GPT-2 byte-to-unicode table."""
    return "".join(byte_encoder[ord(ch)] for ch in b.decode("latin-1"))
+
+
# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None):
    """Re-run byte-pair merging for `token` and return its constituent pieces.

    Starting from single bytes, repeatedly merges the adjacent pair with the
    lowest rank in `mergeable_ranks`. Stops when no adjacent pair is mergeable
    or when the best pair's rank is not strictly below `max_rank`.
    """
    pieces = [bytes([b]) for b in token]
    while True:
        # Find the leftmost adjacent pair with the lowest merge rank.
        best_rank = None
        best_idx = None
        for idx in range(len(pieces) - 1):
            rank = mergeable_ranks.get(pieces[idx] + pieces[idx + 1])
            if rank is not None and (best_rank is None or rank < best_rank):
                best_rank, best_idx = rank, idx
        if best_rank is None or (max_rank is not None and best_rank >= max_rank):
            break
        # Merge the winning pair in place.
        pieces[best_idx : best_idx + 2] = [pieces[best_idx] + pieces[best_idx + 1]]
    return pieces
+
+
def generate_vocab_and_merges(encoder):
    """Build a GPT2-style (vocab, merges) pair from a tiktoken encoder.

    The vocab maps printable token strings to their ranks; merges lists, for every
    multi-byte token, the two pieces whose merge produces it (recovered by
    re-running BPE with the token's own rank as the cutoff).
    """
    mergeable_ranks = encoder._mergeable_ranks

    vocab = {token_bytes_to_string(token): rank for token, rank in mergeable_ranks.items()}

    merges = []
    for token, rank in mergeable_ranks.items():
        if len(token) < 2:
            # Single bytes are leaves: nothing was merged to create them.
            continue
        merged = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(merged) == 2
        merges.append(" ".join(map(token_bytes_to_string, merged)))

    # Special tokens keep their ids and raw contents.
    vocab.update(encoder._special_tokens)
    return vocab, merges
+
+
def convert_tiktoken(tokenizer, output_dir):
    """Convert a tiktoken-based Emu3 tokenizer into GPT2-style HF tokenizer files.

    Writes vocab.json, merges.txt, tokenizer.json, tokenizer_config.json and
    special_tokens_map.json into `output_dir` (created if missing).
    """
    encoder = tokenizer.tokenizer
    vocab, merges = generate_vocab_and_merges(encoder)
    # Register every special token except <|extra_0|>, which is repurposed below
    # as the image placeholder token.
    added_tokens = [
        {
            "id": id,
            "content": content,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
        for content, id in encoder._special_tokens.items()
        if content != "<|extra_0|>"
    ]

    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer_config.json
    tokenizer_config_template = {
        "add_prefix_space": False,
        "bos_token": "<|extra_203|>",
        "clean_up_tokenization_spaces": False,
        "eos_token": "<|extra_204|>",
        "pad_token": "<|endoftext|>",
    }
    tokenizer_config_template.update({"tokenizer_class": "GPT2Tokenizer"})
    # Sort keys for a stable, diff-friendly JSON file.
    tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0]))

    # add placeholder image token by taking one of the reserved tokens
    # NOTE(review): the replacement content is the empty string "" here — this
    # looks like the literal image-token glyph was lost in this build artifact;
    # verify against the upstream conversion script.
    reserved_token_id = vocab["<|extra_0|>"]
    vocab[""] = reserved_token_id
    del vocab["<|extra_0|>"]
    added_tokens.append(
        {
            "id": reserved_token_id,
            "content": "",
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
    )

    os.makedirs(output_dir, exist_ok=True)

    pre_tokenizer = {
        "type": "ByteLevel",
        "add_prefix_space": False,
        "trim_offsets": True,
        "use_regex": True,
    }

    # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer.json
    tokenizer_template = {
        "version": "1.0",
        "truncation": None,
        "padding": None,
        "added_tokens": added_tokens,
        "normalizer": None,
        "pre_tokenizer": pre_tokenizer,
        "post_processor": None,
        "decoder": {
            "type": "ByteLevel",
            "add_prefix_space": True,
            "trim_offsets": True,
            "use_regex": True,
        },
        "model": {
            "type": "BPE",
            "dropout": None,
            "unk_token": None,
            "continuing_subword_prefix": "",
            "end_of_word_suffix": "",
            "fuse_unk": False,
            "byte_fallback": False,
            "vocab": vocab,
            "merges": merges,
        },
    }

    # Save to files
    with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as fp:
        json.dump(vocab, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "tokenizer.json"), "w", encoding="utf-8") as fp:
        json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "tokenizer_config.json"), "w", encoding="utf-8") as fp:
        json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False)

    with open(os.path.join(output_dir, "special_tokens_map.json"), "w", encoding="utf-8") as fp:
        json.dump(
            {
                "bos_token": "<|extra_203|>",
                "eos_token": "<|extra_204|>",
                "pad_token": "<|endoftext|>",
            },
            fp,
            indent=2,
            ensure_ascii=False,
        )

    with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as fp:
        fp.write("#version: 0.2\n")
        fp.write("\n".join(merges))
+
+
# Ordered regex -> replacement rules applied to every checkpoint key by
# convert_state_dict_to_hf; patterns are applied sequentially, so later rules
# may rewrite the output of earlier ones.
KEYS_TO_MODIFY_MAPPING = {
    "^encoder": "model.vqmodel.encoder",
    "^decoder": "model.vqmodel.decoder",
    "^post_quant_conv": "model.vqmodel.post_quant_conv",
    "^quant_conv": "model.vqmodel.quant_conv",
    "^quantize": "model.vqmodel.quantize",
    "^model": "text_model.model",
    r"lm_head\.weight": "text_model.lm_head.weight",
    r"^text_model\.model\.vqmodel": "vqmodel",
    # rename QKV proj for the VQ-VAE model because we use SiglipAttention
    r"\.q\.": ".q_proj.",
    r"\.k\.": ".k_proj.",
    r"\.v\.": ".v_proj.",
    r"\.proj_out\.": ".out_proj.",
    # move the attention norms outside of attention modules
    r"mid\.attn_1\.norm\.": "mid.attn_norm.",
    r"attn\.0\.norm\.": "attn_norms.0.",
    r"attn\.1\.norm\.": "attn_norms.1.",
    r"attn\.2\.norm\.": "attn_norms.2.",
    r"attn\.3\.norm\.": "attn_norms.3.",
    # isolate down/mid/up into separate classes for readability
    r"\.down\.": ".down_block.down.",
    r"\.up\.": ".up_block.up.",
    r"\.mid\.": ".middle_block.",
}
+
+
def convert_state_dict_to_hf(old_state_dict, new_state_dict):
    """Copy `old_state_dict` into `new_state_dict` (also returned) with HF key names.

    Every key is rewritten by applying each regex in KEYS_TO_MODIFY_MAPPING in
    order; 4-d attention projection conv weights are squeezed because the HF
    model implements those projections as linear layers.
    """
    conv_attn_suffixes = ("q.weight", "k.weight", "v.weight", "proj_out.weight")
    for old_key, tensor in old_state_dict.items():
        # convert conv layers in attn to linear
        if tensor.ndim == 4 and old_key.endswith(conv_attn_suffixes):
            tensor = tensor.squeeze()

        new_key = old_key
        for pattern, replacement in KEYS_TO_MODIFY_MAPPING.items():
            new_key = re.sub(pattern, replacement, new_key)

        new_state_dict[new_key] = tensor
    return new_state_dict
+
+
def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test_inference=False):
    """Convert the original BAAI Emu3 checkpoints into a HF `Emu3ForConditionalGeneration`.

    The tiktoken tokenizer, image processor, VQ-VAE and LLM backbone are converted
    and saved under `output_dir`. Optionally the result is pushed to the hub and/or
    sanity-checked with a short generation run.

    Args:
        vq_model_id (`str`): Hub ID of the Emu3 VQ-VAE vision tokenizer.
        llm_model_id (`str`): Hub ID of the Emu3 backbone LLM ("-Chat" or "-Gen" variant).
        output_dir (`str`): Directory where the converted processor and model are written.
        hub_model_id (`str`, *optional*): Hub repo ID to push the converted model/processor to.
        test_inference (`bool`, *optional*): Whether to run a short generation to check the conversion.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Convert and save processor
    tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True)
    convert_tiktoken(tokenizer_tiktoken, output_dir)
    # fix: the original script had a duplicated `extra_special_tokens =` assignment here
    extra_special_tokens = {
        "image_token": "",
        "boi_token": "<|image start|>",
        "eoi_token": "<|image end|>",
        "image_wrapper_token": "<|image token|>",
        "eof_token": "<|extra_201|>",
    }
    tokenizer_converted = AutoTokenizer.from_pretrained(output_dir, extra_special_tokens=extra_special_tokens)
    tokenizer_converted.padding_side = "left"

    image_processor = Emu3ImageProcessor.from_pretrained(vq_model_id)
    processor = Emu3Processor(image_processor, tokenizer_converted, chat_template=CHAT_TEMPLATE)
    processor.save_pretrained(output_dir)

    # load models
    model_llm = AutoModelForCausalLM.from_pretrained(
        llm_model_id,
        trust_remote_code=True,
    )
    model_vqgan = AutoModel.from_pretrained(vq_model_id, trust_remote_code=True)
    with open(f"{output_dir}/tokenizer.json", "r", encoding="utf-8") as file:
        tokenizer_config = json.load(file)
    vocabulary_map = tokenizer_config["model"]["vocab"]

    text_config = Emu3TextConfig(
        max_position_embeddings=model_llm.config.max_position_embeddings,
        rope_scaling={"rope_type": "default"},
    )
    config = Emu3Config(text_config=text_config, vocabulary_map=vocabulary_map)

    with init_empty_weights():
        model = Emu3ForConditionalGeneration(config=config)
        model.generation_config = GenerationConfig(
            do_sample=True,
            top_k=2048,
            max_new_tokens=50_000,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

    # merge both checkpoints into one state dict with HF-style keys
    state_dict = {}
    state_dict = convert_state_dict_to_hf(model_llm.state_dict(), state_dict)
    state_dict = convert_state_dict_to_hf(model_vqgan.state_dict(), state_dict)

    model.load_state_dict(state_dict, assign=True, strict=True)
    model.save_pretrained(output_dir, safe_serialization=True)

    if hub_model_id is not None:
        model.push_to_hub(hub_model_id)
        processor.push_to_hub(hub_model_id)

    if test_inference and llm_model_id.endswith("Chat"):
        # Short inference on a few examples to check if generation makes sense
        print("Loading the checkpoint in a Emu3 model...")
        print("*" * 100)
        model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
        processor = Emu3Processor.from_pretrained(output_dir)

        conversation = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a helpful assistant."},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Please tell me about this art work and its artist."},
                    {"type": "image"},
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        image = Image.open(
            requests.get(
                "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True
            ).raw
        )
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
        length = inputs.input_ids.shape[1]

        out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0]

        print(f"Generation for single-image: {generated_text}")
        print("*" * 100)
    elif test_inference and llm_model_id.endswith("Gen"):
        processor = Emu3Processor.from_pretrained(output_dir)
        model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")

        inputs = processor(
            text=[
                "a portrait of young girl. masterpiece, film grained, best quality.",
                "a dog running under the rain",
            ],
            padding=True,
            return_tensors="pt",
            return_for_image_generation=True,
        )
        inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16)

        neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
        neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0")

        image_sizes = inputs.pop("image_sizes")
        HEIGHT, WIDTH = image_sizes[0]
        VISUAL_TOKENS = model.vocabulary_mapping.image_tokens

        # Hoisted out of the constraint callback: `prefix_allowed_tokens_fn` runs
        # at every generation step, and re-encoding these fixed strings there was
        # pure overhead.
        image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device)
        eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0]
        eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0]
        pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0]
        eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0]
        eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0]

        def prefix_allowed_tokens_fn(batch_id, input_ids):
            # Constrain generation to the fixed raster layout of an image:
            # `width` visual tokens + EOL per row, repeated `height` times,
            # then EOF, EOI, EOS and padding.
            height, width = HEIGHT, WIDTH
            visual_tokens = VISUAL_TOKENS

            position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0]
            offset = input_ids.shape[0] - position
            if offset % (width + 1) == 0:
                return (eol_token_id,)
            elif offset == (width + 1) * height + 1:
                return (eof_token_id,)
            elif offset == (width + 1) * height + 2:
                return (eoi_token_id,)
            elif offset == (width + 1) * height + 3:
                return (eos_token_id,)
            elif offset > (width + 1) * height + 3:
                return (pad_token_id,)
            else:
                return visual_tokens

        out = model.generate(
            **inputs,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            negative_prompt_ids=neg_inputs.input_ids,
            negative_prompt_attention_mask=neg_inputs.attention_mask,
        )

        image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH)
        images = processor.postprocess(
            list(image.float()), return_tensors="PIL.Image.Image"
        )  # internally we convert to np but it's not supported in bf16 precision
        for i, image in enumerate(images["pixel_values"]):
            image.save(f"result_{i}.png")
+
+
def main():
    """CLI entry point: parse arguments and run the Emu3 checkpoint conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vq_model_id",
        help="Model ID of Emu3 VQ-VAE on the hub",
        default="BAAI/Emu3-VisionTokenizer",
    )
    parser.add_argument(
        "--llm_model_id",
        # fix: typo "bacbone" -> "backbone" in the help text
        help="Model ID of Emu3 backbone LLM on the hub",
        default="BAAI/Emu3-Chat",
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model",
    )
    parser.add_argument(
        "--hub_model_id",
        help="Model ID in the hub where to push the model.",
    )
    parser.add_argument(
        "--test_inference",
        action="store_true",
        help="Whether to load the model for generation to test it's converted correctly.",
    )
    args = parser.parse_args()
    convert_model(
        vq_model_id=args.vq_model_id,
        llm_model_id=args.llm_model_id,
        output_dir=args.output_dir,
        hub_model_id=args.hub_model_id,
        test_inference=args.test_inference,
    )
+
+
# Script entry point: run the conversion with command-line arguments.
if __name__ == "__main__":
    main()
diff --git a/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63269c99ef12e0303fe2c7a0bcf93f4918eb459
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py
@@ -0,0 +1,552 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Dict, Iterable, List, Optional, Union
+
+import numpy as np
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature
+from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format
+from ...image_utils import (
+ OPENAI_CLIP_MEAN,
+ OPENAI_CLIP_STD,
+ ChannelDimension,
+ ImageInput,
+ PILImageResampling,
+ VideoInput,
+ get_image_size,
+ infer_channel_dimension_format,
+ is_scaled_image,
+ is_valid_image,
+ make_list_of_images,
+ to_numpy_array,
+ valid_images,
+ validate_preprocess_arguments,
+)
+from ...utils import TensorType, is_vision_available, logging
+
+
+if is_vision_available():
+ from PIL import Image
+
+logger = logging.get_logger(__name__)
+
+
def make_batched_images(images) -> List[List[ImageInput]]:
    """
    Accepts images in list or nested list format, and makes a flat list of images for preprocessing.

    Args:
        images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image(s).

    Returns:
        list: A flat list of images.

    Raises:
        ValueError: If `images` is empty or is not a valid image / (nested) list of images.
    """
    # Nested list of images: flatten one level. The extra truthiness checks keep an
    # empty (or empty-nested) input from raising a bare IndexError below; it now
    # falls through to the informative ValueError instead.
    if (
        isinstance(images, (list, tuple))
        and images
        and isinstance(images[0], (list, tuple))
        and images[0]
        and is_valid_image(images[0][0])
    ):
        return [img for img_list in images for img in img_list]

    # Flat list of images: return as-is.
    elif isinstance(images, (list, tuple)) and images and is_valid_image(images[0]):
        return images

    # Single image: wrap in a list.
    elif is_valid_image(images):
        return [images]

    raise ValueError(f"Could not make batched images from {images}")
+
+
def smart_resize(
    height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
):
    """Compute a resize target (height, width) satisfying three conditions:

    1. both dimensions are divisible by ``factor``;

    2. the total pixel count lies within ``[min_pixels, max_pixels]``;

    3. the aspect ratio of the image is preserved as closely as possible.

    Raises:
        ValueError: if either side is smaller than ``factor``, or the absolute
            aspect ratio exceeds 200.
    """
    if min(height, width) < factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
    if max(height, width) / min(height, width) > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
        )

    # Snap each side to the nearest multiple of `factor`.
    new_height = round(height / factor) * factor
    new_width = round(width / factor) * factor

    if new_height * new_width > max_pixels:
        # Too many pixels: shrink uniformly, flooring to multiples of `factor`
        # so the result cannot exceed `max_pixels`.
        scale = math.sqrt((height * width) / max_pixels)
        new_height = math.floor(height / scale / factor) * factor
        new_width = math.floor(width / scale / factor) * factor
    elif new_height * new_width < min_pixels:
        # Too few pixels: enlarge uniformly, ceiling to multiples of `factor`
        # so the result cannot fall below `min_pixels`.
        scale = math.sqrt(min_pixels / (height * width))
        new_height = math.ceil(height * scale / factor) * factor
        new_width = math.ceil(width * scale / factor) * factor
    return new_height, new_width
+
+
class Emu3ImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Emu3 image processor that dynamically resizes images based on the original images.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions.
        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
            Resampling filter to use when resizing the image.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image.
        image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
            Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether to convert the image to RGB.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
            number of patches in the batch. Padding will be applied to the bottom and right with zeros.
        min_pixels (`int`, *optional*, defaults to `512 * 512`):
            The min pixels of the image to resize the image.
        max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
            The max pixels of the image to resize the image.
        spatial_factor (`int`, *optional*, defaults to 8):
            The spatial downsample factor the image will be downsampled in feature extracting phase
    """

    model_input_names = ["pixel_values", "image_sizes"]

    def __init__(
        self,
        do_resize: bool = True,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        do_pad: bool = True,
        min_pixels: int = 512 * 512,
        max_pixels: int = 1024 * 1024,
        spatial_factor: int = 8,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.resample = resample
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
        self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.spatial_factor = spatial_factor
        self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
        self.do_convert_rgb = do_convert_rgb
        # fix: the `do_pad` argument was accepted but never stored, so
        # `preprocess(..., do_pad=None)` crashed with AttributeError on the
        # `self.do_pad` fallback.
        self.do_pad = do_pad

    def _preprocess(
        self,
        images: Union[ImageInput, VideoInput],
        do_resize: Optional[bool] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
            vision_info (`List[Dict]`, *optional*):
                Optional list of dictionaries containing additional information about vision inputs.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Scale factor to use if rescaling the image.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        images = make_list_of_images(images)

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # The target size is computed once from the first image so that every
        # image/frame in the batch is resized to the same (height, width).
        height, width = get_image_size(images[0], channel_dim=input_data_format)
        resized_height, resized_width = height, width
        processed_images = []
        for image in images:
            if do_resize:
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=self.spatial_factor,
                    min_pixels=self.min_pixels,
                    max_pixels=self.max_pixels,
                )
                image = resize(
                    image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
                )

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            processed_images.append(image)

        images = np.array(processed_images)
        return images

    def _pad_for_batching(
        self,
        pixel_values: List[np.ndarray],
        image_sizes: List[List[int]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches.

        Args:
            pixel_values (`List[np.ndarray]`):
                An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`)
            image_sizes (`List[List[int]]`):
                A list of sizes for each image in `pixel_values` in (height, width) format.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use same as the input image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format for the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                If unset, will use the inferred format of the input image.

        Returns:
            List[`np.ndarray`]: The padded images.
        """

        # Pad every image on the bottom/right up to the largest (height, width)
        # in the batch so they can be stacked into one array.
        max_shape = (
            max([size[0] for size in image_sizes]),
            max([size[1] for size in image_sizes]),
        )
        pixel_values = [
            pad(
                image,
                padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])),
                data_format=data_format,
                input_data_format=input_data_format,
            )
            for image, size in zip(pixel_values, image_sizes)
        ]
        return pixel_values

    def preprocess(
        self,
        images: ImageInput,
        do_resize: Optional[bool] = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: Optional[bool] = None,
        do_pad: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            do_pad (`bool`, *optional*, defaults to `True`):
                Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
                number of patches in the batch. Padding will be applied to the bottom and right with zeros.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        do_pad = do_pad if do_pad is not None else self.do_pad

        if images is not None:
            images = make_batched_images(images)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        pixel_values = []
        for image in images:
            image = self._preprocess(
                image,
                do_resize=do_resize,
                resample=resample,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                do_convert_rgb=do_convert_rgb,
                input_data_format=input_data_format,
            )
            pixel_values.extend(image)

        image_sizes = [image.shape[-2:] for image in pixel_values]
        if do_pad:
            pixel_values = self._pad_for_batching(pixel_values, image_sizes)
            pixel_values = np.array(pixel_values)

        return BatchFeature(
            data={"pixel_values": pixel_values, "image_sizes": image_sizes}, tensor_type=return_tensors
        )

    def postprocess(
        self,
        images: ImageInput,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Union[str, TensorType] = "PIL.Image.Image",
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess.
        The parameters should be same as in preprocess.
        Args:
            images (`ImageInput`):
                Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        # Postprocess inverts preprocess, so the default rescale factor is the
        # reciprocal of the one used during preprocessing.
        rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)
        if isinstance(images[0], Image.Image):
            return images if len(images) > 1 else images[0]

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        pixel_values = []
        for image in images:
            image = to_numpy_array(image)
            if do_normalize:
                image = self.unnormalize(
                    image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
                )

            if do_rescale:
                image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
                image = image.clip(0, 255).astype(np.uint8)

            if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
                image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
                pixel_values.append(Image.fromarray(image))
            else:
                pixel_values.extend(image)

        data = {"pixel_values": pixel_values}
        return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None

        return BatchFeature(data=data, tensor_type=return_tensors)

    def unnormalize(
        self,
        image: np.array,
        image_mean: Union[float, Iterable[float]],
        image_std: Union[float, Iterable[float]],
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.array:
        """
        Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`.
        image = (image * image_std) + image_mean
        Args:
            image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`):
                Batch of pixel values to postprocess.
            image_mean (`float` or `Iterable[float]`):
                The mean to use for unnormalization.
            image_std (`float` or `Iterable[float]`):
                The standard deviation to use for unnormalization.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        num_channels = 3

        if isinstance(image_mean, Iterable):
            if len(image_mean) != num_channels:
                raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}")
        else:
            image_mean = [image_mean] * num_channels

        if isinstance(image_std, Iterable):
            if len(image_std) != num_channels:
                raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}")
        else:
            image_std = [image_std] * num_channels

        # Express unnormalization as a normalize() call with inverted
        # parameters: (x - (-mean/std)) / (1/std) == x * std + mean.
        rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std))
        rev_image_std = tuple(1 / std for std in image_std)
        image = self.normalize(
            image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format
        )
        return image
+
+
+__all__ = ["Emu3ImageProcessor"]
diff --git a/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py
new file mode 100644
index 0000000000000000000000000000000000000000..375b1beb230ce89ef388f0f7b521bb1d28acec00
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py
@@ -0,0 +1,1976 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/emu3/modular_emu3.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_emu3.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import cached_property
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
+from ...integrations import use_kernel_forward_from_hub
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_layers import GradientCheckpointingLayer
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ LossKwargs,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ is_torch_flex_attn_available,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
+
+
+if is_torch_flex_attn_available():
+ from torch.nn.attention.flex_attention import BlockMask
+
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)


# Config class name consumed by the docstring decorators used below.
_CONFIG_FOR_DOC = "Emu3Config"
+
+
@use_kernel_forward_from_hub("RMSNorm")
class Emu3RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): no mean
    subtraction and no bias, with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Emu3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the normalization in float32 for numerical stability,
        # then cast back to the caller's dtype.
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
class Emu3MLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # Gate branch is activated, then modulates the up branch elementwise.
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
+
+
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
+
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` (shape `[batch, seq, head_dim]`)
            are unsqueezed so they broadcast against q and k. Use 1 for
            `[batch, heads, seq, head_dim]` layouts and 2 for
            `[batch, seq, heads, head_dim]` layouts.

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors rotated with RoPE.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Inlined rotate_half: (x1, x2) -> (-x2, x1) over the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    q_embed = (q * cos) + (_rotate(q) * sin)
    k_embed = (k * cos) + (_rotate(k) * sin)
    return q_embed, k_embed
+
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands the
    key/value heads from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim), duplicating each KV head
    n_rep times consecutively.
    """
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
+
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (non-fused) scaled-dot-product attention with grouped-query
    support: KV heads are duplicated to match the query head count, logits are
    masked and softmaxed in float32, and the weighted values are returned as
    (batch, seq, heads, head_dim) together with the attention weights.
    """
    n_rep = module.num_key_value_groups
    # Inlined repeat_kv: duplicate each KV head n_rep consecutive times.
    if n_rep > 1:
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)

    scores = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Trim the mask to the current key length (it may be pre-allocated larger).
        scores = scores + attention_mask[:, :, :, : key.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    output = torch.matmul(probs, value).transpose(1, 2).contiguous()

    return output, probs
+
+
class Emu3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.config = config
        # The cache needs the layer index to know which layer's KV states to update.
        self.layer_idx = layer_idx
        # head_dim may be set explicitly in the config; otherwise derive it from hidden_size.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Grouped-query attention: number of query heads sharing each key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # K/V projections are sized by num_key_value_heads (may be fewer than query heads).
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Compute self-attention over `hidden_states`.

        `position_embeddings` is the (cos, sin) pair from the rotary embedding
        module; `past_key_value`, when given, is updated in place with this
        layer's new key/value states before attention is computed.
        """
        input_shape = hidden_states.shape[:-1]
        # (*batch, seq, hidden) -> (*batch, seq, num_heads, head_dim); -1 infers the head count.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Rotary position embeddings are applied to queries and keys only.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Default to the eager path; swap in the configured fused kernel when possible.
        attention_interface: Callable = eager_attention_forward

        if self.config._attn_implementation != "eager":
            # SDPA cannot return attention weights, so fall back to eager in that case.
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back to (*batch, seq, num_heads * head_dim) before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
+
+
class Emu3DecoderLayer(GradientCheckpointingLayer):
    """
    One transformer decoder layer: pre-norm self-attention and a pre-norm gated
    MLP, each wrapped in a residual connection with dropout on the sub-block
    output.
    """

    def __init__(self, config: Emu3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx)

        self.mlp = Emu3MLP(config)
        self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.attention_dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*):
                `(batch_size, sequence_length)` when flash attention is used, or
                `(batch_size, 1, query_sequence_length, key_sequence_length)` for default attention.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights of this layer.
            use_cache (`bool`, *optional*):
                If `True`, key/value states are cached to speed up decoding.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached key/value projection states.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Positions of the input tokens in the sequence.
            kwargs (`dict`, *optional*):
                Ignored here; forwarded for FSDP and similar wrappers.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        attn_input = self.input_layernorm(hidden_states)
        attn_output, attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.dropout(attn_output)

        # --- feed-forward sub-block (pre-norm + residual) ---
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + self.dropout(mlp_output)

        return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
+
+
class Emu3VQVAEVectorQuantizer(nn.Module):
    """
    Vector quantization over a learned codebook.

    Maps each continuous latent vector to the index of its nearest codebook
    embedding under squared Euclidean distance, as in the VQ-VAE paper. The
    distance is expanded as ||z||^2 + ||e||^2 - 2 z.e so the full difference
    tensor is never materialized, and indices can be remapped post hoc.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
        # Small uniform init scaled by the codebook size (original VQ-VAE recipe).
        self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)

    def forward(self, hidden_state: torch.Tensor):
        batch_size, temporal, channels, height, width = hidden_state.shape
        # Move channels last and flatten to (num_vectors, channels).
        flat = hidden_state.permute(0, 1, 3, 4, 2).contiguous().view(-1, channels)

        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
        z_sq = torch.sum(flat**2, dim=1, keepdim=True)
        e_sq = torch.sum(self.embedding.weight**2, dim=1)

        # "bd,dn->bn"
        cross = 2 * torch.matmul(flat, self.embedding.weight.transpose(0, 1))
        distances = z_sq + e_sq - cross

        nearest = torch.argmin(distances, dim=1)
        return nearest.view(batch_size, temporal, height, width)
+
+
class Emu3VQVAEEncoderConvDownsample(nn.Module):
    """2x spatial downsample: stride-2 3x3 conv with asymmetric (right/bottom) zero padding."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        # torch convs only support symmetric padding, so pad right/bottom by hand.
        padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
+
+
class Emu3VQVAEEncoderConvUpsample(nn.Module):
    """2x spatial upsample: nearest-neighbor interpolation followed by a 3x3 conv."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, hidden_states):
        upsampled = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
+
+
class Emu3VQVAEConv3d(nn.Module):
    """
    3D convolution with hand-rolled padding: spatial dims are padded so the
    output size matches the stride (extra unit goes to the left/top when the
    required pad is odd), and the temporal dim is padded by 2 on the front
    only — no padding is added after the last frame.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        kernel_size: Tuple[int],
        stride: Tuple[int],
    ):
        super().__init__()

        # F.pad expects pairs ordered from the last dim backwards:
        # (W_left, W_right, H_left, H_right, T_front, T_back).
        spatial_pads = [k - s for k, s in zip(kernel_size[1:], stride[1:])]
        padding = []
        for pad in reversed(spatial_pads):
            padding.extend((pad // 2 + pad % 2, pad // 2))
        padding.extend((2, 0))  # front-only temporal padding
        self.padding = tuple(padding)

        self.conv = nn.Conv3d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
        )

    def forward(self, hidden_states: torch.Tensor):
        padded = F.pad(hidden_states, self.padding)
        return self.conv(padded)
+
+
class Emu3VQVAESpatialNorm(nn.Module):
    """
    Group norm whose output is modulated per-pixel by a second feature map:
    scale and shift are predicted from `quant_states` by two 1x1 convs after
    resizing it to the normalized tensor's spatial size.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_layer = nn.GroupNorm(
            num_channels=out_channels,
            num_groups=32,
            eps=1e-6,
            affine=True,
        )

        # 1x1 convs mapping the conditioning map to per-channel scale (y) and shift (b).
        self.conv_y = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv_b = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        # Match the conditioning map to the spatial size of the input.
        resized = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
        normalized = self.norm_layer(hidden_states)
        return normalized * self.conv_y(resized) + self.conv_b(resized)
+
+
class Emu3VQVAETemporalUpsample(nn.Module):
    """2x temporal upsample: nearest interpolation along the frame axis, then a 3x3x3 conv."""

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(
            in_channel,
            out_channel,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )

    def forward(self, hidden_states: torch.Tensor):
        batch_size, channels, temporal, height, width = hidden_states.shape
        # Collapse (C, H, W) so the temporal axis is last, interpolate along it,
        # then restore the (B, C, T, H, W) layout.
        flattened = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
        upsampled = F.interpolate(flattened, scale_factor=2.0, mode="nearest")
        restored = upsampled.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
        return self.conv(restored)
+
+
class Emu3VQVAETemporalDownsample(nn.Module):
    """2x temporal downsample via a (4,3,3) conv with temporal stride 2."""

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
    ):
        super().__init__()
        self.conv = Emu3VQVAEConv3d(
            in_channel,
            out_channel,
            kernel_size=(4, 3, 3),
            stride=(2, 1, 1),
        )

    def forward(self, hidden_states: torch.Tensor):
        return self.conv(hidden_states)
+
+
class Emu3VQVAETemporalResnetBlock(nn.Module):
    """
    Temporal (3D) residual block: norm -> swish -> conv, twice, plus a skip
    connection. A 1x1x1 conv projects the residual when the input and output
    channel counts differ.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        # BUGFIX: resolve the None default *before* building the layers. The
        # previous code passed the raw `out_channels` (possibly None) to
        # conv1/norm2/conv2/nin_shortcut, which crashed whenever the argument
        # was omitted even though the signature advertises a default.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.BatchNorm3d(in_channels)
        self.conv1 = Emu3VQVAEConv3d(
            in_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )
        self.norm2 = nn.BatchNorm3d(out_channels)
        self.conv2 = Emu3VQVAEConv3d(
            out_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            stride=(1, 1, 1),
        )
        if self.in_channels != self.out_channels:
            # Channel-matching projection for the residual path.
            self.nin_shortcut = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)  # swish, computed in place
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            residual = self.nin_shortcut(residual)

        return residual + hidden_states
+
+
class Emu3VQVAEResnetBlock(nn.Module):
    """
    2D residual block: norm -> swish -> 3x3 conv, twice, plus a skip
    connection. When `quant_channels` is given the norms are spatially
    conditioned (Emu3VQVAESpatialNorm) on the quantized states; otherwise
    plain GroupNorm is used. A 1x1 conv projects the residual when the input
    and output channel counts differ.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        quant_channels: Optional[int] = None,
    ):
        super().__init__()
        resolved_out = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = resolved_out
        self.quant_channels = quant_channels

        if quant_channels is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
            self.norm2 = nn.GroupNorm(num_channels=resolved_out, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
            self.norm2 = Emu3VQVAESpatialNorm(quant_channels, resolved_out)

        self.conv1 = nn.Conv2d(in_channels, resolved_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(resolved_out, resolved_out, kernel_size=3, stride=1, padding=1)

        if self.in_channels != self.out_channels:
            # 1x1 projection so the residual matches the output channel count.
            self.nin_shortcut = nn.Conv2d(in_channels, resolved_out, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None):
        # Spatially-conditioned norms take the quantized states as an extra argument.
        norm_args = () if self.quant_channels is None else (quant_channels,)
        shortcut = hidden_states

        states = self.norm1(hidden_states, *norm_args)
        states *= torch.sigmoid(states)  # swish, computed in place
        states = self.conv1(states)

        states = self.norm2(states, *norm_args)
        states *= torch.sigmoid(states)
        states = self.conv2(states)

        if self.in_channels != self.out_channels:
            shortcut = self.nin_shortcut(shortcut)

        return shortcut + states
+
+
class Emu3VQVAEAttentionBlock(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        # Bidirectional attention: every spatial token may attend to every other.
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # for compatibility with the attention interface
        self.num_key_value_groups = 1

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        def split_heads(projected):
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        # Default to the eager path; swap in the configured fused kernel when possible.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
        )

        # Merge heads back and project out.
        attn_output = self.out_proj(attn_output.reshape(batch_size, seq_length, embed_dim).contiguous())

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights
+
+
class Emu3VQVAEGroupNorm(nn.GroupNorm):
    """
    torch GroupNorm that accepts (and ignores) an optional `quant_states`
    keyword, so it is call-compatible with Emu3VQVAESpatialNorm and callers
    can use either norm without branching.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input, quant_states=None):
        # `quant_states` is deliberately unused — plain group normalization.
        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
+
+
class Emu3VQVAEMiddleBlock(nn.Module):
    """Bottleneck stack: resnet block -> (norm + self-attention over pixels) -> resnet block."""

    def __init__(self, config, in_channels, quant_channels=None):
        super().__init__()

        self.block_1 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )
        self.attn_1 = Emu3VQVAEAttentionBlock(config)
        # Use the spatially-conditioned norm only when quantized states are available.
        if quant_channels is None:
            self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)

        self.block_2 = Emu3VQVAEResnetBlock(
            in_channels=in_channels,
            out_channels=in_channels,
            quant_channels=quant_channels,
        )

    def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None):
        hidden_states = self.block_1(hidden_states, quant_states)

        # Self-attention treats every spatial position as a token (pre-norm + residual).
        residual = hidden_states
        normed = self.attn_norm(hidden_states, quant_states)
        batch_size, channels, height, width = normed.shape
        tokens = normed.view(batch_size, channels, height * width).transpose(1, 2)
        attended = self.attn_1(tokens)[0]
        attended = attended.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
        hidden_states = residual + attended

        return self.block_2(hidden_states, quant_states)
+
+
class Emu3VQVAEDownBlock(nn.Module):
    """
    Encoder downsampling trunk: one stage per entry in `config.channel_multiplier`,
    each with `num_res_blocks` resnet blocks (optionally followed by per-pixel
    self-attention) and a 2x spatial downsample between stages.
    """

    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        channel_multiplier = config.channel_multiplier

        # Input multiplier of stage i is the output multiplier of stage i-1 (1 for the first stage).
        in_channel_multiplier = (1,) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                    )
                )
                block_in = block_out
                # One attention block per resnet block at the configured levels,
                # so attn/attn_norms stay index-aligned with `block`.
                if config.attn_resolutions is not None and i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))

            # Plain namespace module holding this stage's sub-layers.
            down = nn.Module()
            down.block = block
            down.attn = attn
            down.attn_norms = attn_norms
            if i_level != self.num_resolutions - 1:
                down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
            self.down.append(down)

    def forward(self, hidden_states: torch.FloatTensor):
        for i_level, blocks in enumerate(self.down):
            for i_block in range(self.num_res_blocks):
                hidden_states = blocks.block[i_block](hidden_states)
                if len(blocks.attn) > 0:
                    # Per-pixel self-attention with a pre-norm residual connection.
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states)

                    # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token.
                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]

                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                    hidden_states = residual + hidden_states

            # Downsample after every stage except the last.
            if i_level != self.num_resolutions - 1:
                hidden_states = blocks.downsample(hidden_states)

        return hidden_states
+
+
class Emu3VQVAEUpBlock(nn.Module):
    """
    Decoder upsampling trunk, mirroring Emu3VQVAEDownBlock: one stage per
    channel multiplier with `num_res_blocks + 1` spatially-conditioned resnet
    blocks (optionally with attention) and a 2x upsample between stages.
    """

    def __init__(self, config):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]

        self.up = nn.ModuleList()
        # Build from the deepest (lowest-resolution) stage outwards, inserting at
        # index 0 so self.up ends up ordered by increasing level.
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            attn_norms = nn.ModuleList()
            block_out = config.base_channels * config.channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    Emu3VQVAEResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        quant_channels=quant_channels,
                    )
                )
                block_in = block_out
                # Keep attn/attn_norms index-aligned with `block` (one entry per resnet block).
                if i_level in config.attn_resolutions:
                    attn.append(Emu3VQVAEAttentionBlock(config))
                    attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))

            # Plain namespace module holding this stage's sub-layers.
            up = nn.Module()
            up.block = block
            up.attn = attn
            up.attn_norms = attn_norms
            if i_level != 0:
                up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)

            self.up.insert(0, up)

    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
        # Walk the stages deepest-first (reverse of storage order).
        for i_level, blocks in enumerate(self.up[::-1]):
            for i_block in range(self.num_res_blocks + 1):
                hidden_states = blocks.block[i_block](hidden_states, quant_states)
                if len(blocks.attn) > 0:
                    # Per-pixel self-attention with spatially-conditioned pre-norm.
                    residual = hidden_states
                    hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)

                    # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token.
                    batch_size, channels, height, width = hidden_states.shape
                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
                    hidden_states = blocks.attn[i_block](hidden_states)[0]

                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
                    hidden_states = residual + hidden_states
            # Upsample after every stage except the final (shallowest) one.
            if i_level != len(self.up) - 1:
                hidden_states = blocks.upsample(hidden_states)

        return hidden_states
+
+
class Emu3VQVAEEncoder(nn.Module):
    """
    VQ-VAE encoder: per-frame 2D downsampling (conv stem -> down blocks ->
    middle block -> output conv), followed by temporal downsampling convs and
    temporal resnet blocks applied across the frame axis.
    """

    def __init__(self, config):
        super().__init__()

        base_channels = config.base_channels
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier
        # NOTE(review): doubling presumably lets the encoder emit two latent
        # maps (e.g. mean/variance pairs) — inferred from the flag name, confirm
        # against the config documentation.
        out_channels = 2 * latent_channels if double_latent else latent_channels
        block_in = base_channels * channel_multiplier[-1]

        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
        self.down_block = Emu3VQVAEDownBlock(config)
        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)

        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # One temporal downsample layer per factor of 2 in temporal_downsample_factor.
        temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        self.time_res_stack = nn.ModuleList()

        for i in range(temporal_down_blocks):
            conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
            self.time_conv.append(conv)

        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(
                in_channels=out_channels,
                out_channels=out_channels,
            )
            self.time_res_stack.append(time_res_conv)

    def forward(self, pixel_values: torch.LongTensor):
        # Fold the temporal axis into the batch so the 2D stages see (B*T, C, H, W).
        temporal_dim = pixel_values.shape[1]
        pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])

        # downsampling & middle
        hidden_states = self.conv_in(pixel_values)
        hidden_states = self.down_block(hidden_states)
        hidden_states = self.middle_block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        # Swish activation, computed in place.
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        # Restore the temporal axis and move it next to batch: (B, C, T, H, W) for the 3D convs.
        hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        # temporal convs
        for conv in self.time_conv:
            hidden_states = conv(hidden_states)
            hidden_states *= torch.sigmoid(hidden_states)

        for layer in self.time_res_stack:
            hidden_states = layer(hidden_states)

        # Back to (B, T, C, H, W).
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)

        return hidden_states
+
+
class Emu3VQVAEDecoder(nn.Module):
    """
    VQ-VAE decoder: temporal resnet blocks and temporal upsampling over the
    frame axis, then per-frame 2D decoding (conv stem -> middle block -> up
    blocks -> output conv), spatially conditioned on the quantized states.
    """

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__()

        quant_channels = config.embed_dim
        block_in = config.base_channels * config.channel_multiplier[-1]
        self.time_res_stack = nn.ModuleList()
        for _ in range(config.num_res_blocks):
            time_res_conv = Emu3VQVAETemporalResnetBlock(
                in_channels=config.latent_channels, out_channels=config.latent_channels
            )
            self.time_res_stack.append(time_res_conv)

        # One temporal upsample layer per factor of 2 in temporal_downsample_factor.
        temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
        self.time_conv = nn.ModuleList()
        for i in range(temp_upsample_block_num):
            conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
            self.time_conv.append(conv)

        self.conv_in = nn.Conv2d(
            config.latent_channels,
            block_in,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
        self.up_block = Emu3VQVAEUpBlock(config)

        # The up block ends at the first channel multiplier; size the output head accordingly.
        block_in = config.base_channels * config.channel_multiplier[0]
        self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
        self.conv_out = nn.Conv2d(
            block_in,
            config.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
        # Stack hidden and quant states along the batch axis so the shared
        # temporal layers process both in a single pass; they are split again below.
        hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)

        # temporal convs
        for layer in self.time_res_stack:
            hidden_quant_states = layer(hidden_quant_states)

        for layer in self.time_conv:
            hidden_quant_states = layer(hidden_quant_states)
            # Swish activation, computed in place.
            hidden_quant_states *= torch.sigmoid(hidden_quant_states)

        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
        hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
        # Fold the temporal axis into the batch for the 2D stages.
        hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
        quant_states = quant_states.reshape(-1, *quant_states.shape[2:])

        hidden_states = self.conv_in(hidden_states)

        # middle & upsampling
        hidden_states = self.middle_block(hidden_states, quant_states)
        hidden_states = self.up_block(hidden_states, quant_states)

        hidden_states = self.norm_out(hidden_states, quant_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
+
+
# Shared "bare model" docstring for the VQ-VAE wrapper below; consumed by
# `add_start_docstrings`.
EMU3_VQ_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Emu3VQVAEConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
    """,
    EMU3_VQ_START_DOCSTRING,
)
class Emu3VQVAE(PreTrainedModel):
    config_class = Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flex_attn = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]

    def _init_weights(self, module):
        """Kaiming init for convs/linears, constant init for norms, normal for embeddings."""
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_()
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)

        self.config = config

        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        # Spatial downsampling factor: one factor of 2 per extra channel multiplier.
        self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)

        # 3D projections into/out of the quantizer's embedding space.
        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)
        self.eval()  # Emu3's VQ model is frozen

        self.post_init()

    def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor):
        """Encode pixels to discrete token grids, one (cropped) grid per image.

        A 4D input is treated as a batch of still images and repeated along a new
        temporal axis to match the temporal downsampling factor.
        """
        is_image = pixel_values.ndim == 4
        if is_image:
            temporal = self.config.temporal_downsample_factor
            batch_size, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
        else:
            batch_size, temporal, channels, height, width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        hidden_states = self.quant_conv(hidden_states)

        # b c t h w -> b t c h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(hidden_states)

        image_tokens = codes.squeeze(1) if is_image else codes

        # Crop each token grid to the image's true (unpadded) size in latent space.
        image_tokens = [
            single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
            for single_image, size in zip(image_tokens, image_sizes)
        ]

        return image_tokens

    def decode(self, hidden_states: torch.Tensor):
        """Decode discrete token grids back to pixels.

        A 3D input (no temporal axis) is treated as still images; only the first
        reconstructed frame is returned in that case.
        """
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape
        # Look up codebook vectors for every token id.
        quant = self.quantize.embedding(hidden_states.flatten())

        channels = quant.shape[-1]
        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
        post_quant = self.post_quant_conv(quant)

        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        # The decoder consumes both the projected and raw quantized states.
        video = self.decoder(post_quant, quant)
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        return video[:, 0] if is_image else video
+
+
class Emu3ImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    Args:
        vocab_map (`dict`):
            Tokenizer vocabulary mapping token strings to BPE ids. Visual tokens
            are expected to be named ``<|visual token XXXXXX|>`` where ``XXXXXX``
            is the zero-padded VQGAN codebook index.
    """

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        # End-of-line marker appended after every row of image tokens.
        self.eol_token_id = vocab_map.get("<|extra_200|>")
        # NOTE(review): looking up the empty string here looks wrong — presumably
        # this should be the image placeholder token (e.g. "<image>"); confirm
        # against the tokenizer vocabulary. Left unchanged to preserve behavior.
        self.image_token_id = vocab_map.get("")

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all visual tokens."""
        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def image_tokens_str(self):
        """Sorted token strings of all visual tokens."""
        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def img2bpe(self):
        """Map VQGAN codebook index -> BPE id (index parsed from the token name)."""
        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

    @cached_property
    def bpe2img(self):
        """Inverse of `img2bpe`: BPE id -> VQGAN codebook index."""
        return {v: k for k, v in self.img2bpe.items()}

    @cached_property
    def bpe2img_mapping_tensor(self):
        """`bpe2img` as a dense lookup tensor, enabling vectorized fancy indexing."""
        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
        for k, v in self.bpe2img.items():
            mapping[k] = v
        return mapping

    @cached_property
    def img2bpe_mapping_tensor(self):
        """`img2bpe` as a dense lookup tensor, enabling vectorized fancy indexing."""
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Convert a 2D grid of VQGAN indices to BPE ids, appending an EOL token per row.

        Fix: the parameter was annotated `List[torch.Tensor]`, but the body has
        always operated on a single tensor (`img_batch.device`, tensor indexing).
        """
        device = img_batch.device
        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
        # Lookup runs on CPU because the mapping tensor lives on CPU; the result
        # is moved back to the input's device at the end.
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
        return img_tokens.to(device)

    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Convert BPE ids back to VQGAN indices, dropping the trailing EOL column."""
        device = img_batch.device
        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
+
+
# Shared "bare model" docstring for the Emu3 language-model classes below;
# consumed by `add_start_docstrings`.
EMU3_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Emu3Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    "The bare emu3 Model outputting raw hidden-states without any specific head on top.",
    EMU3_START_DOCSTRING,
)
class Emu3PreTrainedModel(PreTrainedModel):
    """Base class for Emu3 language-model variants: declares supported attention
    backends / cache types and the weight-initialization scheme."""

    config_class = Emu3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Modules that must not be split across devices by `accelerate`.
    _no_split_modules = [
        "Emu3DecoderLayer",
    ]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_quantized_cache = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_param_buffer_assignment = False
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for linear/conv/embedding weights;
        ones for RMSNorm scales; zeros for biases and padding embeddings."""
        std = self.config.get_text_config().initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Emu3RMSNorm):  # noqa: F821
            module.weight.data.fill_(1.0)
+
+
class Emu3RotaryEmbedding(nn.Module):
    """Rotary position embedding for Emu3.

    Produces the cos/sin tables for the positions in ``position_ids``,
    honoring any RoPE scaling strategy declared on the config.
    """

    def __init__(self, config: Emu3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            self.rope_type = "default"
        else:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Buffer is non-persistent: inverse frequencies are recomputed from config.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos = position_ids[:, None, :].float()

        # Compute the tables in float32 (autocast disabled) for numerical stability,
        # then cast back to the activation dtype.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = (freq.float() @ pos.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
# Forward-pass argument documentation for the text-only model/LM head; consumed
# by `add_start_docstrings_to_model_forward`.
EMU3_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Has to be an instance of [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
+
+
@add_start_docstrings(
    "The bare Emu3Text Model outputting raw hidden-states without any specific head on top.",
    EMU3_START_DOCSTRING,
)
class Emu3TextModel(Emu3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Emu3TextDecoderLayer`]

    Args:
        config: Emu3TextConfig
    """

    def __init__(self, config: Emu3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Emu3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # XOR: exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # Absolute positions of the current tokens, offset by any cached prefix.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build (or bypass) the 4D causal mask appropriate for the configured
        attention backend. Returns `None` when the backend handles causality itself."""
        if self.config._attn_implementation == "flash_attention_2":
            # FA2 only needs the 2D padding mask (if any padding exists at all).
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start from "everything masked", then clear positions at or before
            # each query's cache position (the causal triangle).
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Merge in the 2D padding mask: positions where both masks are zero
                # stay attendable; padded positions are set to -inf.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
    """Emu3 text decoder with a language-modeling head on top (tied to the input embeddings)."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Emu3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig")
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForCausalLM
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (slice(-0, None) is the full-sequence special case).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
+EMU3_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)):
+ The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
+ [`Emu3ImageProcessor`] for processing images).
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
+ [`Emu3ImageProcessor`] for processing images).
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Has to be an instance of [`~cache_utils.Cache`] instance, see our
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["text_model.lm_head.weight"]
+ _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.text_model = Emu3ForCausalLM._from_config(config.text_config)
+ self.vqmodel = Emu3VQVAE(config.vq_config)
+ self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.text_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.text_model.set_input_embeddings(value)
+
+ def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
+ """
+ Tokenizes images into discrete tokens with VQGAN module. Converts
+ obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
+ special tokens.
+
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+ The tensors corresponding to the input images.
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
+ The sizes of the images in the batch, being (height, width) for each image.
+ """
+ image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
+ bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
+ bpe_tokens = torch.cat(bpe_tokens_list)
+ return bpe_tokens
+
+ @torch.no_grad
+ def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
+ """
+ Decodes generated image tokens from language model to continuous pixel values
+ with VQGAN module via upsampling.
+
+ Args:
+ image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
+ The tensors corresponding to the input images.
+ height (`int`):
+ Height of the generated image before upsampling.
+ width (`int`):
+ Width of the generated image before upsampling.
+ """
+ sequences = image_tokens[:, :-3].view(-1, height, width + 1)
+ image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
+ image = self.vqmodel.decode(image_tokens)
+ return image
+
+ @can_return_tuple
+ @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
+ >>> import torch
+ >>> import requests
+ >>> from PIL import Image
+
+ >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16)
+ >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")
+
+ >>> conversation = [
+ ... {
+ ... "role": "system",
+ ... "content": [
+ ... {"type": "text", "text": "You are a helpful assistant."},
+ ... ],
+ ... },
+ ... {
+ ... "role": "user",
+ ... "content": [
+ ... {"type": "image"},
+ ... {"type": "text", "text": "Please describe the image."},
+ ... ],
+ ... },
+ ... ]
+
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+ >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
+
+ >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)
+
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+ >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if pixel_values is not None:
+ image_tokens = self.get_image_tokens(pixel_values, image_sizes)
+ special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
+ image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
+ input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ return self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ pixel_values=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ position_ids=position_ids,
+ pixel_values=pixel_values,
+ use_cache=use_cache,
+ **kwargs,
+ )
+
+ if cache_position[0] != 0:
+ model_inputs["pixel_values"] = None
+
+ return model_inputs
+
+
+__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]
diff --git a/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e35e71d21baacd46310d029b9caf7d5925457a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py
@@ -0,0 +1,1321 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import cached_property
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+ CausalLMOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ can_return_tuple,
+ logging,
+ replace_return_docstrings,
+)
+from ..chameleon.modeling_chameleon import (
+ ChameleonPreTrainedModel,
+ ChameleonVQVAEEncoderConvDownsample,
+)
+from ..llama.modeling_llama import (
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+)
+from ..siglip.modeling_siglip import SiglipAttention
+from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
+
+
+_CONFIG_FOR_DOC = "Emu3Config"
+_CHECKPOINT_FOR_DOC = "BAAI/Emu3-Chat-hf"
+
+logger = logging.get_logger(__name__)
+
+
+# Has extra dropout which no other model in the library has
+class Emu3DecoderLayer(LlamaDecoderLayer):
+    """Llama decoder layer with an extra dropout applied on both residual branches
+    (attention output and MLP output) — unique to Emu3 in the library."""
+
+    def __init__(self, config: Emu3Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+        # Reuses the attention dropout probability for the residual dropout.
+        self.dropout = nn.Dropout(config.attention_dropout)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        # Pre-norm before self-attention.
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
+        )
+        # Dropout on the attention branch before the residual add (Emu3-specific).
+        hidden_states = residual + self.dropout(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        # Same residual dropout on the MLP branch.
+        hidden_states = residual + self.dropout(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        return outputs
+
+
+class Emu3VQVAEVectorQuantizer(nn.Module):
+    """
+    A module for vector quantization using learned embedding vectors.
+
+    This module implements the quantization process similar to the one described in
+    the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
+    input vectors into discrete codebook vectors, which are learned during training.
+    Current implementation improves over previous ones by avoiding costly matrix multiplications
+    and allowing for post-hoc remapping of indices.
+    """
+
+    def __init__(self, config: Emu3VQVAEConfig):
+        super().__init__()
+        self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
+        # Standard VQ-VAE codebook init: uniform in [-1/K, 1/K] with K = codebook size.
+        self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
+
+    def forward(self, hidden_state: torch.Tensor):
+        # hidden_state: (batch, temporal, channels, height, width). Move channels last
+        # so each spatio-temporal position becomes one `channels`-dim vector.
+        batch_size, temporal, channels, height, width = hidden_state.shape
+        hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
+        hidden_state_flattened = hidden_state.view(-1, channels)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
+        embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
+
+        # "bd,dn->bn",
+        distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
+        distances = hidden_state_sum + embedding_sum - distances
+
+        # Pick the nearest codebook entry per position and restore (batch, temporal, height, width).
+        min_encoding_indices = torch.argmin(distances, dim=1)
+        min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
+        return min_encoding_indices
+
+
+class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample):
+    # Identical to the Chameleon VQ-VAE conv downsample; re-declared under the Emu3 namespace.
+    pass
+
+
+class Emu3VQVAEEncoderConvUpsample(nn.Module):
+    """2x nearest-neighbor spatial upsample followed by a 3x3 channel-preserving conv."""
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, hidden_states):
+        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAEConv3d(nn.Module):
+    """3D convolution with manual, asymmetric padding.
+
+    Spatial dims get `kernel - stride` total padding (extra pixel on the leading side
+    when odd); the temporal dim is padded with a fixed 2 frames at the front and none
+    at the back (causal-style padding — the hard-coded `(2, 0)` below).
+    """
+
+    def __init__(
+        self,
+        in_channel: int,
+        out_channel: int,
+        kernel_size: Tuple[int],
+        stride: Tuple[int],
+    ):
+        super().__init__()
+
+        # Per spatial dim (kernel_size[1:] = H, W): total pad so output size matches stride.
+        padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
+        self.padding = ()
+        # F.pad takes pairs ordered from the LAST dim backwards: (W_l, W_r, H_l, H_r, T_front, T_back).
+        for pad_size in padding_sizes[::-1]:
+            self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
+        self.padding += (2, 0)
+
+        self.conv = nn.Conv3d(
+            in_channel,
+            out_channel,
+            kernel_size,
+            stride=stride,
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = F.pad(hidden_states, self.padding)
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAESpatialNorm(nn.Module):
+    """Spatially-conditioned normalization: group-normalizes `hidden_states`, then
+    modulates the result with a per-pixel scale and shift predicted (via 1x1 convs)
+    from the quantized states resized to the current resolution."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(
+            num_channels=out_channels,
+            num_groups=32,
+            eps=1e-6,
+            affine=True,
+        )
+
+        # 1x1 conv producing the multiplicative modulation map.
+        self.conv_y = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+        # 1x1 conv producing the additive modulation map.
+        self.conv_b = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
+        # Resize the quantized map to the current spatial resolution before modulating.
+        quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
+        hidden_states = self.norm_layer(hidden_states)
+        hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states)
+        return hidden_states
+
+
+class Emu3VQVAETemporalUpsample(nn.Module):
+    """Doubles the temporal dimension via nearest-neighbor interpolation, then applies a 3D conv."""
+
+    def __init__(
+        self,
+        in_channel: int,
+        out_channel: int,
+    ):
+        super().__init__()
+        self.conv = Emu3VQVAEConv3d(
+            in_channel,
+            out_channel,
+            kernel_size=(3, 3, 3),
+            stride=(1, 1, 1),
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        batch_size, channels, temporal, height, width = hidden_states.shape
+        # Fold channels/height/width together so 1D interpolate acts on the temporal axis.
+        hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
+        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # Restore the (batch, channels, temporal, height, width) layout.
+        hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAETemporalDownsample(nn.Module):
+    """Halves the temporal dimension with a temporally-strided (stride 2) 3D conv."""
+
+    def __init__(
+        self,
+        in_channel: int,
+        out_channel: int,
+    ):
+        super().__init__()
+        self.conv = Emu3VQVAEConv3d(
+            in_channel,
+            out_channel,
+            kernel_size=(4, 3, 3),
+            stride=(2, 1, 1),
+        )
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class Emu3VQVAETemporalResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+
+ self.norm1 = nn.BatchNorm3d(in_channels)
+ self.conv1 = Emu3VQVAEConv3d(
+ in_channels,
+ out_channels,
+ kernel_size=(3, 3, 3),
+ stride=(1, 1, 1),
+ )
+ self.norm2 = nn.BatchNorm3d(out_channels)
+ self.conv2 = Emu3VQVAEConv3d(
+ out_channels,
+ out_channels,
+ kernel_size=(3, 3, 3),
+ stride=(1, 1, 1),
+ )
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = nn.Conv3d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ hidden_states *= torch.sigmoid(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states *= torch.sigmoid(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.in_channels != self.out_channels:
+ residual = self.nin_shortcut(residual)
+
+ return residual + hidden_states
+
+
+class Emu3VQVAEResnetBlock(nn.Module):
+    """2D residual block with two swish-activated 3x3 conv stages.
+
+    When `quant_channels` is given (decoder side) the norms are quant-conditioned
+    `Emu3VQVAESpatialNorm`s; otherwise (encoder side) plain GroupNorms are used.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        quant_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.quant_channels = quant_channels
+
+        if quant_channels is None:
+            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
+            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
+        else:
+            self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
+            self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)
+
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        self.conv2 = nn.Conv2d(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        if self.in_channels != self.out_channels:
+            # 1x1 conv to project the residual when the channel count changes.
+            self.nin_shortcut = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            )
+
+    def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None):
+        # SpatialNorm needs the quant states as an extra argument; GroupNorm takes none.
+        norm_args = () if self.quant_channels is None else (quant_channels,)
+
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states, *norm_args)
+        # Swish/SiLU computed in place: x * sigmoid(x).
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states, *norm_args)
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.in_channels != self.out_channels:
+            residual = self.nin_shortcut(residual)
+
+        return residual + hidden_states
+
+
+class Emu3VQVAEAttentionBlock(SiglipAttention):
+    """Siglip-style multi-head self-attention reused inside the VQ-VAE blocks."""
+
+    def __init__(self, config: Emu3VQVAEConfig):
+        super().__init__(config)
+
+        # for compatibility with the attention interface
+        self.num_key_value_groups = 1
+
+
+class Emu3VQVAEGroupNorm(nn.GroupNorm):
+    """
+    Same as the torch GroupNorm with the only difference that this ones accepts
+    an optional kwarg `quant_states` which is not used. This class makes it easier to
+    use SpatialNorm or GroupNorm without conditionals
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def forward(self, input, quant_states=None):
+        # `quant_states` is accepted (and ignored) so call sites can treat this
+        # interchangeably with Emu3VQVAESpatialNorm.
+        return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class Emu3VQVAEMiddleBlock(nn.Module):
+    """Bottleneck of the VQ-VAE: resnet block -> self-attention (with residual) -> resnet block.
+
+    With `quant_channels` set (decoder) the norms are quant-conditioned SpatialNorms;
+    otherwise (encoder) plain GroupNorms.
+    """
+
+    def __init__(self, config, in_channels, quant_channels=None):
+        super().__init__()
+
+        self.block_1 = Emu3VQVAEResnetBlock(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            quant_channels=quant_channels,
+        )
+        self.attn_1 = Emu3VQVAEAttentionBlock(config)
+        if quant_channels is None:
+            self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
+        else:
+            self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
+
+        self.block_2 = Emu3VQVAEResnetBlock(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            quant_channels=quant_channels,
+        )
+
+    def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None):
+        hidden_states = self.block_1(hidden_states, quant_states)
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states, quant_states)
+        # Flatten spatial positions into a sequence so attention runs over H*W tokens.
+        batch_size, channels, height, width = hidden_states.shape
+        hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+        hidden_states = self.attn_1(hidden_states)[0]
+        # Back to (batch, channels, height, width) and add the residual.
+        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
+        hidden_states = residual + hidden_states
+        hidden_states = self.block_2(hidden_states, quant_states)
+        return hidden_states
+
+
+class Emu3VQVAEDownBlock(nn.Module):
+    """Encoder-side pyramid: per resolution level, `num_res_blocks` resnet blocks
+    (optionally followed by self-attention at levels listed in `attn_resolutions`),
+    then a conv downsample between levels."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+        base_channels = config.base_channels
+        channel_multiplier = config.channel_multiplier
+
+        # Prepend 1 so level i reads base*multiplier[i-1] channels (base at level 0).
+        in_channel_multiplier = (1,) + tuple(channel_multiplier)
+        self.in_channel_multiplier = in_channel_multiplier
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            attn_norms = nn.ModuleList()
+            block_in = base_channels * in_channel_multiplier[i_level]
+            block_out = base_channels * channel_multiplier[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    Emu3VQVAEResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                    )
+                )
+                block_in = block_out
+                if config.attn_resolutions is not None and i_level in config.attn_resolutions:
+                    attn.append(Emu3VQVAEAttentionBlock(config))
+                    attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))
+
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            down.attn_norms = attn_norms
+            # No downsample after the last (coarsest) level.
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
+            self.down.append(down)
+
+    def forward(self, hidden_states: torch.FloatTensor):
+        for i_level, blocks in enumerate(self.down):
+            for i_block in range(self.num_res_blocks):
+                hidden_states = blocks.block[i_block](hidden_states)
+                if len(blocks.attn) > 0:
+                    residual = hidden_states
+                    hidden_states = blocks.attn_norms[i_block](hidden_states)
+
+                    # Attention over flattened spatial positions, then residual add.
+                    batch_size, channels, height, width = hidden_states.shape
+                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+                    hidden_states = blocks.attn[i_block](hidden_states)[0]
+
+                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
+                    hidden_states = residual + hidden_states
+
+            if i_level != self.num_resolutions - 1:
+                hidden_states = blocks.downsample(hidden_states)
+
+        return hidden_states
+
+
+class Emu3VQVAEUpBlock(nn.Module):
+    """Decoder-side pyramid mirroring `Emu3VQVAEDownBlock`: per level, `num_res_blocks + 1`
+    quant-conditioned resnet blocks (optionally with attention), then an upsample
+    between levels. Levels are built coarsest-first but stored finest-first via insert(0)."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        self.num_resolutions = len(config.channel_multiplier)
+        self.num_res_blocks = config.num_res_blocks
+
+        quant_channels = config.embed_dim
+        block_in = config.base_channels * config.channel_multiplier[-1]
+
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            attn_norms = nn.ModuleList()
+            block_out = config.base_channels * config.channel_multiplier[i_level]
+            # One extra resnet block per level compared to the encoder.
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    Emu3VQVAEResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        quant_channels=quant_channels,
+                    )
+                )
+                block_in = block_out
+                if i_level in config.attn_resolutions:
+                    attn.append(Emu3VQVAEAttentionBlock(config))
+                    attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))
+
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            up.attn_norms = attn_norms
+            # Level 0 (finest) has no upsample.
+            if i_level != 0:
+                up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)
+
+            self.up.insert(0, up)
+
+    def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
+        # Iterate coarsest -> finest (self.up is stored finest-first, hence the reversal).
+        for i_level, blocks in enumerate(self.up[::-1]):
+            for i_block in range(self.num_res_blocks + 1):
+                hidden_states = blocks.block[i_block](hidden_states, quant_states)
+                if len(blocks.attn) > 0:
+                    residual = hidden_states
+                    hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)
+
+                    # Attention over flattened spatial positions, then residual add.
+                    batch_size, channels, height, width = hidden_states.shape
+                    hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
+                    hidden_states = blocks.attn[i_block](hidden_states)[0]
+
+                    hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
+                    hidden_states = residual + hidden_states
+            # The last iterated level is the finest one, which has no upsample module.
+            if i_level != len(self.up) - 1:
+                hidden_states = blocks.upsample(hidden_states)
+
+        return hidden_states
+
+
+class Emu3VQVAEEncoder(nn.Module):
+    """VQ-VAE encoder: conv-in -> spatial down pyramid -> middle block -> conv-out,
+    followed by strided temporal downsampling convs and temporal resnet blocks."""
+
+    def __init__(self, config):
+        super().__init__()
+
+        base_channels = config.base_channels
+        in_channels = config.in_channels
+        double_latent = config.double_latent
+        latent_channels = config.latent_channels
+        channel_multiplier = config.channel_multiplier
+        # Doubled latent channels presumably reserve room for mean/logvar-style outputs — TODO confirm.
+        out_channels = 2 * latent_channels if double_latent else latent_channels
+        block_in = base_channels * channel_multiplier[-1]
+
+        self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
+        self.down_block = Emu3VQVAEDownBlock(config)
+        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)
+
+        self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        # One stride-2 temporal conv per power of two in the downsample factor.
+        temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
+        self.time_conv = nn.ModuleList()
+        self.time_res_stack = nn.ModuleList()
+
+        for i in range(temporal_down_blocks):
+            conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
+            self.time_conv.append(conv)
+
+        for _ in range(config.num_res_blocks):
+            time_res_conv = Emu3VQVAETemporalResnetBlock(
+                in_channels=out_channels,
+                out_channels=out_channels,
+            )
+            self.time_res_stack.append(time_res_conv)
+
+    def forward(self, pixel_values: torch.LongTensor):
+        # Fold the temporal dim into the batch so the 2D stages treat frames independently.
+        temporal_dim = pixel_values.shape[1]
+        pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
+
+        # downsampling & middle
+        hidden_states = self.conv_in(pixel_values)
+        hidden_states = self.down_block(hidden_states)
+        hidden_states = self.middle_block(hidden_states)
+
+        # end
+        hidden_states = self.norm_out(hidden_states)
+        # Swish/SiLU computed in place: x * sigmoid(x).
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        # Restore the temporal dim and move it next to batch for the 3D temporal stages.
+        hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+        # temporal convs
+        for conv in self.time_conv:
+            hidden_states = conv(hidden_states)
+            hidden_states *= torch.sigmoid(hidden_states)
+
+        for layer in self.time_res_stack:
+            hidden_states = layer(hidden_states)
+
+        # Back to (batch, temporal, channels, height, width).
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+        return hidden_states
+
+
+class Emu3VQVAEDecoder(nn.Module):
+    """VQ-VAE decoder: temporal resnet/upsample stages, then conv-in -> quant-conditioned
+    middle block -> spatial up pyramid -> spatially-normalized conv-out."""
+
+    def __init__(self, config: Emu3VQVAEConfig):
+        super().__init__()
+
+        quant_channels = config.embed_dim
+        block_in = config.base_channels * config.channel_multiplier[-1]
+        self.time_res_stack = nn.ModuleList()
+        for _ in range(config.num_res_blocks):
+            time_res_conv = Emu3VQVAETemporalResnetBlock(
+                in_channels=config.latent_channels, out_channels=config.latent_channels
+            )
+            self.time_res_stack.append(time_res_conv)
+
+        # One 2x temporal upsample per power of two in the downsample factor.
+        temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
+        self.time_conv = nn.ModuleList()
+        for i in range(temp_upsample_block_num):
+            conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
+            self.time_conv.append(conv)
+
+        self.conv_in = nn.Conv2d(
+            config.latent_channels,
+            block_in,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+        self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
+        self.up_block = Emu3VQVAEUpBlock(config)
+
+        block_in = config.base_channels * config.channel_multiplier[0]
+        self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
+        self.conv_out = nn.Conv2d(
+            block_in,
+            config.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
+        # Stack hidden and quant states along the batch dim so the temporal stages
+        # process both in a single pass; they are split again afterwards.
+        hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
+        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
+
+        # temporal convs
+        for layer in self.time_res_stack:
+            hidden_quant_states = layer(hidden_quant_states)
+
+        for layer in self.time_conv:
+            hidden_quant_states = layer(hidden_quant_states)
+            # Swish/SiLU computed in place: x * sigmoid(x).
+            hidden_quant_states *= torch.sigmoid(hidden_quant_states)
+
+        hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
+        hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
+        # Fold the temporal dim into the batch for the 2D spatial stages.
+        hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
+        quant_states = quant_states.reshape(-1, *quant_states.shape[2:])
+
+        hidden_states = self.conv_in(hidden_states)
+
+        # middle & upsampling
+        hidden_states = self.middle_block(hidden_states, quant_states)
+        hidden_states = self.up_block(hidden_states, quant_states)
+
+        hidden_states = self.norm_out(hidden_states, quant_states)
+        hidden_states *= torch.sigmoid(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
+
# Boilerplate docstring injected into the Emu3VQVAE class docstring via
# `@add_start_docstrings` below.
EMU3_VQ_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Emu3VQVAEConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
+
@add_start_docstrings(
    """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
    """,
    EMU3_VQ_START_DOCSTRING,
)
class Emu3VQVAE(PreTrainedModel):
    config_class = Emu3VQVAEConfig
    base_model_prefix = "emuvideovq"
    main_input_name = "pixel_values"
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_flex_attn = True
    _no_split_modules = [
        "Emu3VQVAETemporalResnetBlock",
        "Emu3VQVAEAttentionBlock",
        "Emu3VQVAEResnetBlock",
        "Emu3VQVAEVectorQuantizer",
    ]

    def _init_weights(self, module):
        """VQGAN-style initialization: Kaiming for convolutions/linears,
        constants for norms, standard normal for embeddings."""
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                nn.init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            nn.init.constant_(module.weight, 1.0)
            nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_()
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def __init__(self, config: Emu3VQVAEConfig):
        super().__init__(config)

        self.config = config

        self.encoder = Emu3VQVAEEncoder(config)
        self.decoder = Emu3VQVAEDecoder(config)
        self.quantize = Emu3VQVAEVectorQuantizer(config)
        # Each channel-multiplier level halves the spatial resolution.
        self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)

        # 3D convs bridging encoder latents <-> quantizer embedding space.
        self.quant_conv = Emu3VQVAEConv3d(
            config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.post_quant_conv = Emu3VQVAEConv3d(
            config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
        )
        self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)
        self.eval()  # Emu3's VQ model is frozen

        self.post_init()

    def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor):
        """Encode images (4D input) or videos (5D input) into grids of
        discrete codebook indices, cropped to each image's true size."""
        is_image = pixel_values.ndim == 4
        if is_image:
            # Repeat a still image along time so the temporal convs see a full
            # `temporal_downsample_factor`-frame clip.
            temporal = self.config.temporal_downsample_factor
            batch_size, channels, height, width = pixel_values.shape
            pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
        else:
            batch_size, temporal, channels, height, width = pixel_values.shape

        hidden_states = self.encoder(pixel_values)

        # b t c h w -> b c t h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        hidden_states = self.quant_conv(hidden_states)

        # b c t h w -> b t c h w
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
        codes = self.quantize(hidden_states)

        # Drop the singleton time dim for still images.
        image_tokens = codes.squeeze(1) if is_image else codes

        # Crop each token grid to the (height, width) of the original image,
        # scaled down by the spatial factor; sizes come from the processor.
        image_tokens = [
            single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
            for single_image, size in zip(image_tokens, image_sizes)
        ]

        return image_tokens

    def decode(self, hidden_states: torch.Tensor):
        """Decode grids of codebook indices back to pixel space. A 3D input
        (batch, height, width) is treated as a single image per sample."""
        is_image = hidden_states.ndim == 3
        if is_image:
            hidden_states = hidden_states.unsqueeze(1)

        batch_size, temporal, height, width = hidden_states.shape
        # Look up codebook vectors for every token index at once.
        quant = self.quantize.embedding(hidden_states.flatten())

        channels = quant.shape[-1]
        quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
        post_quant = self.post_quant_conv(quant)

        # b c t h w -> b t c h w for the decoder.
        quant = quant.permute(0, 2, 1, 3, 4)
        post_quant = post_quant.permute(0, 2, 1, 3, 4)

        video = self.decoder(post_quant, quant)
        video = video.reshape(
            batch_size,
            temporal * self.config.temporal_downsample_factor,
            self.config.out_channels,
            height * self.spatial_scale_factor,
            width * self.spatial_scale_factor,
        )
        # For still images, return only the first reconstructed frame.
        return video[:, 0] if is_image else video
+
+
class Emu3ImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.

    The tokenizer vocabulary encodes each visual token as a string of the form
    `<|visual token XXXXXX|>`, where the six digits are the VQGAN codebook index.

    Args:
        vocab_map (`dict`): Mapping from token string to BPE token id.
    """

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        # End-of-line token id, appended after every row of image tokens.
        self.eol_token_id = vocab_map.get("<|extra_200|>")
        # Bug fix: the placeholder lookup key was an empty string (`""`),
        # which cannot be a vocabulary entry and always yielded `None`;
        # the placeholder marking image positions is the `<image>` token.
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def image_tokens(self):
        """Sorted BPE ids of all visual tokens in the vocabulary."""
        return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def image_tokens_str(self):
        """Sorted token strings of all visual tokens in the vocabulary."""
        return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")])

    @cached_property
    def img2bpe(self):
        """Dict from VQGAN codebook index to BPE token id."""
        # `token[-8:-2]` extracts the 6-digit index from "<|visual token XXXXXX|>".
        return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

    @cached_property
    def bpe2img(self):
        """Dict from BPE token id back to VQGAN codebook index."""
        return {v: k for k, v in self.img2bpe.items()}

    @cached_property
    def bpe2img_mapping_tensor(self):
        """Dense lookup tensor version of `bpe2img` for vectorized indexing."""
        mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int)
        for k, v in self.bpe2img.items():
            mapping[k] = v
        return mapping

    @cached_property
    def img2bpe_mapping_tensor(self):
        """Dense lookup tensor version of `img2bpe` for vectorized indexing."""
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Map a 2D grid of VQGAN indices to BPE ids, appending an EOL column.

        Note: annotation fixed from `List[torch.Tensor]` — the body indexes
        `.device`/`.shape` of a single tensor, one grid per call.
        """
        device = img_batch.device
        eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id
        # The mapping tensor lives on CPU, so the lookup is done there.
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        img_tokens = torch.cat([img_tokens, eol_row], dim=-1)
        return img_tokens.to(device)

    def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
        """Inverse of `convert_img2bpe`: strip the EOL column and map back."""
        device = img_batch.device
        img_batch = img_batch[..., :-1]  # remove last row of EOL tokens
        img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
+
+
class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
    _no_split_modules = [
        "Emu3DecoderLayer",
    ]
    _supports_flex_attn = True

    def _init_weights(self, module):
        """Initialize weights with N(0, initializer_range) for linear/conv and
        embedding layers, and ones for the RMSNorm scale."""
        std = self.config.get_text_config().initializer_range
        if isinstance(module, Emu3RMSNorm):  # noqa: F821
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding embedding at zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
+
+
# Argument documentation for the text-only forwards (`Emu3TextModel`,
# `Emu3ForCausalLM`), injected via `@add_start_docstrings_to_model_forward`.
EMU3_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Has to be an instance of [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
+
+
# Argument documentation for the multimodal forward
# (`Emu3ForConditionalGeneration`), injected via
# `@add_start_docstrings_to_model_forward`.
# Fixes: broken doc markup `([]`Emu3Processor`]` -> `([`Emu3Processor`]` (twice)
# and a missing closing backtick in the `pixel_values` shape annotation.
EMU3_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)`):
            The tensors corresponding to the input images. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
            The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
            [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
            [`Emu3ImageProcessor`] for processing images).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Has to be an instance of [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
+
+
class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
    """Text transformer backbone of Emu3: a Llama-style decoder whose layers
    are replaced by `Emu3DecoderLayer`."""

    def __init__(self, config: Emu3Config):
        super().__init__(config)
        # Swap the Llama decoder layers for the Emu3 variant.
        self.layers = nn.ModuleList(
            [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    @can_return_tuple
    @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
    def forward(self, **super_kwargs):
        """Delegate to `LlamaModel.forward`; see the injected docstring for arguments."""
        # Bug fix: the parent's output was computed but silently discarded
        # (missing `return`), making this forward always return `None`.
        return super().forward(**super_kwargs)
+
+
class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
    """Causal language-modeling head on top of `Emu3TextModel` (text-only)."""

    config_class = Emu3TextConfig

    def __init__(self, config):
        super().__init__(config)
        # Replace the Llama backbone with the Emu3 text model.
        self.model = Emu3TextModel(config)

    @can_return_tuple
    @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig")
    def forward(self, **super_kwargs):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForCausalLM
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        # Bug fixes: `self` was missing from the signature (making this an
        # unusable plain function on the class), and the delegation dropped
        # both the incoming kwargs and the parent's return value.
        # Also: the docstring example imported `Emu3ForConditionalGeneration`
        # while instantiating `Emu3ForCausalLM`; the import now matches.
        return super().forward(**super_kwargs)
+
+
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
    """Emu3 vision-language model: a frozen VQ-VAE tokenizes images into
    discrete BPE ids that are spliced into the text stream, which is then
    modeled by `Emu3ForCausalLM`."""

    _tied_weights_keys = ["text_model.lm_head.weight"]
    _supports_static_cache = False  # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable

    def __init__(self, config):
        super().__init__(config)
        self.text_model = Emu3ForCausalLM._from_config(config.text_config)
        self.vqmodel = Emu3VQVAE(config.vq_config)
        self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the wrapped text model.
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.text_model.set_input_embeddings(value)

    def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module. Converts
        obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
        special tokens.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.
            image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                The sizes of the images in the batch, being (height, width) for each image.

        Returns a single flat 1D tensor of BPE ids for all images, each image's
        grid flattened row-major with an EOL token ending every row.
        """
        image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
        bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
        bpe_tokens = torch.cat(bpe_tokens_list)
        return bpe_tokens

    @torch.no_grad
    def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
        """
        Decodes generated image tokens from language model to continuous pixel values
        with VQGAN module via upsampling.

        Args:
            image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                The tensors corresponding to the input images.
            height (`int`):
                Height of the generated image before upsampling.
            width (`int`):
                Width of the generated image before upsampling.
        """
        # Drop the last 3 tokens, then view as rows of `width + 1` (the +1 is
        # the per-row EOL token added during encoding).
        # NOTE(review): the trailing 3 tokens are presumably image-end special
        # tokens (e.g. EOF/EOI/EOS) — confirm against the generation config.
        sequences = image_tokens[:, :-3].view(-1, height, width + 1)
        image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
        image = self.vqmodel.decode(image_tokens)
        return image

    @can_return_tuple
    @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16)
        >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

        >>> conversation = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."},
        ...         ],
        ...     },
        ...     {
        ...         "role": "user",
        ...         "content": [
        ...             {"type": "image"},
        ...             {"type": "text", "text": "Please describe the image."},
        ...         ],
        ...     },
        ... ]

        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

        >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None:
            # Tokenize images with the VQ-VAE and splice the resulting BPE ids
            # into `input_ids` at the placeholder image-token positions.
            image_tokens = self.get_image_tokens(pixel_values, image_sizes)
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)

        # NOTE(review): `labels` is accepted but not forwarded to the text
        # model, so no loss is computed here — confirm this is intentional.
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            use_cache=use_cache,
            **kwargs,
        )

        # Only pass images on the prefill step (cache position 0); afterwards
        # the image tokens are already in the cache.
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None

        return model_inputs
+
+
# Public API of this module (used by the modular converter / package exports).
__all__ = [
    "Emu3ForConditionalGeneration",
    "Emu3ForCausalLM",
    "Emu3TextModel",
    "Emu3PreTrainedModel",
    "Emu3VQVAE",
]
diff --git a/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a94dc08cd97da2cd0669d263bd423446479d69cf
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py
@@ -0,0 +1,222 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+from ...image_processing_utils import BatchFeature
+from ...image_utils import ImageInput
+from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
class Emu3TextKwargs(TextKwargs, total=False):
    # When True, `Emu3Processor.__call__` prepares inputs for image generation
    # (and raises if `images` are also passed); defaults to False.
    return_for_image_generation: bool
+
+
class Emu3ImagesKwargs(ImagesKwargs, total=False):
    # Target aspect ratio for generated images, e.g. "1:1" (the default).
    ratio: str
    # Target area in pixels for generated images (default 518400).
    image_area: int
+
+
class Emu3ProcessorKwargs(ProcessingKwargs, total=False):
    # Typed kwargs accepted by `Emu3Processor.__call__`, grouped per modality,
    # with the processor's default values declared in `_defaults`.
    text_kwargs: Emu3TextKwargs
    images_kwargs: Emu3ImagesKwargs
    _defaults = {
        "text_kwargs": {
            "return_for_image_generation": False,
        },
        "images_kwargs": {
            "ratio": "1:1",
            "image_area": 518400,
        },
    }
+
+
+class Emu3Processor(ProcessorMixin):
+ r"""
+ Constructs a Emu3 processor which wraps a Emu3 image processor and a GPT2 tokenizer into a single
+ processor.
+
+ [`Emu3Processor`] offers all the functionalities of [`Emu3ImageProcessor`] and [`GPT2TokenizerFast`].
+ See the [`~Emu3Processor.__call__`] and [`~Emu3Processor.decode`] for more information.
+
+ Args:
+ image_processor ([`Emu3ImageProcessor`]):
+ The image processor is a required input.
+ tokenizer ([`Emu3TokenizerFast`]):
+ The tokenizer is a required input.
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+ in a chat into a tokenizable string.
+ """
+
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
+ image_processor_class = "Emu3ImageProcessor"
+
+ def __init__(
+ self,
+ image_processor,
+ tokenizer,
+ chat_template=None,
+ **kwargs,
+ ):
+ self.image_token = tokenizer.image_token # image_token as placeholder to be replaced by vq-vae tokens
+ self.image_token_id = tokenizer.image_token_id
+ self.image_start_token = tokenizer.boi_token # "<|image start|>" fixed tokens for start and end of image
+ self.image_end_token = tokenizer.eoi_token # "<|image end|>"
+ self.fake_token_around_image = tokenizer.image_wrapper_token # "<|image token|>" every image starts with it
+ self.eof_token = tokenizer.eof_token # "<|extra_201|>"
+ self.bos_token = tokenizer.bos_token
+ self.downsample_ratio = 8
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: Optional[ImageInput] = None,
+ text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+ audio=None,
+ videos=None,
+ **kwargs: Unpack[Emu3ProcessorKwargs],
+ ) -> BatchFeature:
+ """
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+ and `kwargs` arguments to Emu3TokenizerFast's [`~Emu3TokenizerFast.__call__`] if `text` is not `None` to encode
+ the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
+ CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
+ of the above two methods for more information.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ # check if images and text inputs are reversed for BC
+
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list) and not isinstance(text[0], str):
+ raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+
+ output_kwargs = self._merge_kwargs(
+ Emu3ProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+ return_for_image_generation = output_kwargs["text_kwargs"].pop("return_for_image_generation", False)
+ ratio = output_kwargs["images_kwargs"].pop("ratio", None)
+ image_area = output_kwargs["images_kwargs"].pop("image_area", None)
+
+ if return_for_image_generation and images is not None:
+ raise ValueError("You should not provide `images` when `return_for_image_generation=True`")
+
+ if not return_for_image_generation and text is None and images is None:
+ raise ValueError("You must provide either text or images when `return_for_image_generation=False`")
+
+ image_features = {}
+ image_start_tokens = f"{self.image_start_token}"
+ image_end_tokens = f"{self.eof_token}{self.image_end_token}"
+
+ # generate text from image + text input, so we add placeholders for image tokens
+ if not return_for_image_generation and images is not None:
+ image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
+ image_sizes = iter(image_features.image_sizes)
+
+ prompt_strings = []
+ for sample in text:
+ while self.image_token in sample:
+ image_size = next(image_sizes)
+ height, width = image_size
+ height = height // self.downsample_ratio
+ width = width // self.downsample_ratio
+ image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code
+
+ image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'' * image_seq_length}{image_end_tokens}"
+ sample = sample.replace(self.image_token, image_placeholder, 1)
+ sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it
+ prompt_strings.append(sample)
+ text = [sample.replace("", self.image_token) for sample in prompt_strings]
+
+ # generate image from text input, so we add begin-of-image tokens from where image generation starts
+ elif return_for_image_generation:
+ height, width = self.calculate_generate_size(ratio, image_area, self.downsample_ratio)
+ image_prompt = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}"
+ text = [f"{self.bos_token}{sample}{image_prompt}" for sample in text]
+ image_features["image_sizes"] = [[height, width]] * len(text)
+
+ # else just generate from text-only input, and we do no special treatment for text
+ return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+ data = self.tokenizer(text, **output_kwargs["text_kwargs"])
+ self._check_special_mm_tokens(text, data, modalities=["image"])
+
+ data.update(**image_features)
+
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ def calculate_generate_size(self, ratio, image_area, spatial_factor):
+ width, height = map(int, ratio.split(":"))
+ current_area = width * height
+ target_ratio = (image_area / current_area) ** 0.5
+
+ token_height = int(round(height * target_ratio / spatial_factor))
+ token_width = int(round(width * target_ratio / spatial_factor))
+ return token_height, token_width
+
+ def postprocess(self, images: ImageInput, **kwargs):
+ return self.image_processor.postprocess(images, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+ This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
+__all__ = ["Emu3Processor"]
diff --git a/docs/transformers/build/lib/transformers/models/encodec/__init__.py b/docs/transformers/build/lib/transformers/models/encodec/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3adeea056604d1d31f946a5cd0bf53ea590ea3aa
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encodec/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # Static type checkers resolve the real submodules eagerly.
    from .configuration_encodec import *
    from .feature_extraction_encodec import *
    from .modeling_encodec import *
else:
    import sys

    # At runtime, replace this package module with a lazy proxy so heavy
    # submodules (e.g. modeling code that imports torch) load on first access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..77fd67727dc390e0b6a402f4b3617e6aae9e9791
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""EnCodec model configuration"""
+
+import math
+from typing import Optional
+
+import numpy as np
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class EncodecConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a
    Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the
    [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        target_bandwidths (`List[float]`, *optional*, defaults to `[1.5, 3.0, 6.0, 12.0, 24.0]`):
            The range of different bandwidths the model can encode audio with.
        sampling_rate (`int`, *optional*, defaults to 24000):
            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
        audio_channels (`int`, *optional*, defaults to 1):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether the audio shall be normalized when passed.
        chunk_length_s (`float`, *optional*):
            If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
        overlap (`float`, *optional*):
            Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
            formula: `int((1.0 - self.overlap) * self.chunk_length)`.
        hidden_size (`int`, *optional*, defaults to 128):
            Intermediate representation dimension.
        num_filters (`int`, *optional*, defaults to 32):
            Number of convolution kernels of first `EncodecConv1d` down sampling layer.
        num_residual_layers (`int`, *optional*, defaults to 1):
            Number of residual layers.
        upsampling_ratios (`Sequence[int]` , *optional*, defaults to `[8, 5, 4, 2]`):
            Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
            will use the ratios in the reverse order to the ones specified here that must match the decoder order.
        norm_type (`str`, *optional*, defaults to `"weight_norm"`):
            Normalization method. Should be in `["weight_norm", "time_group_norm"]`
        kernel_size (`int`, *optional*, defaults to 7):
            Kernel size for the initial convolution.
        last_kernel_size (`int`, *optional*, defaults to 7):
            Kernel size for the last convolution layer.
        residual_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size for the residual layers.
        dilation_growth_rate (`int`, *optional*, defaults to 2):
            How much to increase the dilation with each layer.
        use_causal_conv (`bool`, *optional*, defaults to `True`):
            Whether to use fully causal convolution.
        pad_mode (`str`, *optional*, defaults to `"reflect"`):
            Padding mode for the convolutions.
        compress (`int`, *optional*, defaults to 2):
            Reduced dimensionality in residual branches (from Demucs v3).
        num_lstm_layers (`int`, *optional*, defaults to 2):
            Number of LSTM layers at the end of the encoder.
        trim_right_ratio (`float`, *optional*, defaults to 1.0):
            Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
            equal to 1.0, it means that all the trimming is done at the right.
        codebook_size (`int`, *optional*, defaults to 1024):
            Number of discrete codes that make up VQVAE.
        codebook_dim (`int`, *optional*):
            Dimension of the codebook vectors. If not defined, uses `hidden_size`.
        use_conv_shortcut (`bool`, *optional*, defaults to `True`):
            Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False,
            an identity function will be used, giving a generic residual connection.

    Example:

    ```python
    >>> from transformers import EncodecModel, EncodecConfig

    >>> # Initializing a "facebook/encodec_24khz" style configuration
    >>> configuration = EncodecConfig()

    >>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration
    >>> model = EncodecModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "encodec"

    def __init__(
        self,
        target_bandwidths=None,
        sampling_rate=24_000,
        audio_channels=1,
        normalize=False,
        chunk_length_s=None,
        overlap=None,
        hidden_size=128,
        num_filters=32,
        num_residual_layers=1,
        upsampling_ratios=None,
        norm_type="weight_norm",
        kernel_size=7,
        last_kernel_size=7,
        residual_kernel_size=3,
        dilation_growth_rate=2,
        use_causal_conv=True,
        pad_mode="reflect",
        compress=2,
        num_lstm_layers=2,
        trim_right_ratio=1.0,
        codebook_size=1024,
        codebook_dim=None,
        use_conv_shortcut=True,
        **kwargs,
    ):
        # Use `None` sentinels instead of mutable list defaults so every instance
        # gets its own list (mutable default arguments are shared across calls).
        self.target_bandwidths = [1.5, 3.0, 6.0, 12.0, 24.0] if target_bandwidths is None else target_bandwidths
        self.sampling_rate = sampling_rate
        self.audio_channels = audio_channels
        self.normalize = normalize
        self.chunk_length_s = chunk_length_s
        self.overlap = overlap
        self.hidden_size = hidden_size
        self.num_filters = num_filters
        self.num_residual_layers = num_residual_layers
        self.upsampling_ratios = [8, 5, 4, 2] if upsampling_ratios is None else upsampling_ratios
        self.norm_type = norm_type
        self.kernel_size = kernel_size
        self.last_kernel_size = last_kernel_size
        self.residual_kernel_size = residual_kernel_size
        self.dilation_growth_rate = dilation_growth_rate
        self.use_causal_conv = use_causal_conv
        self.pad_mode = pad_mode
        self.compress = compress
        self.num_lstm_layers = num_lstm_layers
        self.trim_right_ratio = trim_right_ratio
        self.codebook_size = codebook_size
        # Codebook vectors default to the hidden representation dimension.
        self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
        self.use_conv_shortcut = use_conv_shortcut

        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
            )

        super().__init__(**kwargs)

    # This is a property because you might want to change the chunk_length_s on the fly
    @property
    def chunk_length(self) -> Optional[int]:
        """Chunk length in samples, or `None` when chunking is disabled."""
        if self.chunk_length_s is None:
            return None
        else:
            return int(self.chunk_length_s * self.sampling_rate)

    # This is a property because you might want to change the chunk_length_s on the fly
    @property
    def chunk_stride(self) -> Optional[int]:
        """Stride between consecutive chunks in samples, or `None` when chunking is disabled."""
        if self.chunk_length_s is None or self.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.overlap) * self.chunk_length))

    @property
    def frame_rate(self) -> int:
        """Number of latent frames per second: sampling rate divided by the total hop length."""
        hop_length = np.prod(self.upsampling_ratios)
        return math.ceil(self.sampling_rate / hop_length)

    @property
    def num_quantizers(self) -> int:
        """Number of residual VQ codebooks needed to reach the largest target bandwidth."""
        return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10))
+
+
+__all__ = ["EncodecConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1fb0168705f42b943cb5d9bd40aa904c294cfb6
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
@@ -0,0 +1,365 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert EnCodec checkpoints."""
+
+import argparse
+
+import torch
+
+from transformers import (
+ EncodecConfig,
+ EncodecFeatureExtractor,
+ EncodecModel,
+ logging,
+)
+
+
+# checkpoints downloaded from:
+# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th
+# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin
+# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th
+
+
logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.encodec")

# Key-renaming tables from the original EnCodec state-dict layout to the
# transformers layout. `*` stands for a numeric layer index that is filled in
# at load time (see `recursively_load_weights`).
MAPPING_QUANTIZER = {
    "quantizer.vq.layers.*._codebook.inited": "quantizer.layers.*.codebook.inited",
    "quantizer.vq.layers.*._codebook.cluster_size": "quantizer.layers.*.codebook.cluster_size",
    "quantizer.vq.layers.*._codebook.embed": "quantizer.layers.*.codebook.embed",
    "quantizer.vq.layers.*._codebook.embed_avg": "quantizer.layers.*.codebook.embed_avg",
}
# Convolution / LSTM weight names shared by all encoder variants.
MAPPING_ENCODER = {
    "encoder.model.0.conv.conv": "encoder.layers.0.conv",
    "encoder.model.1.block.1.conv.conv": "encoder.layers.1.block.1.conv",
    "encoder.model.1.block.3.conv.conv": "encoder.layers.1.block.3.conv",
    "encoder.model.1.shortcut.conv.conv": "encoder.layers.1.shortcut.conv",
    "encoder.model.3.conv.conv": "encoder.layers.3.conv",
    "encoder.model.4.block.1.conv.conv": "encoder.layers.4.block.1.conv",
    "encoder.model.4.block.3.conv.conv": "encoder.layers.4.block.3.conv",
    "encoder.model.4.shortcut.conv.conv": "encoder.layers.4.shortcut.conv",
    "encoder.model.6.conv.conv": "encoder.layers.6.conv",
    "encoder.model.7.block.1.conv.conv": "encoder.layers.7.block.1.conv",
    "encoder.model.7.block.3.conv.conv": "encoder.layers.7.block.3.conv",
    "encoder.model.7.shortcut.conv.conv": "encoder.layers.7.shortcut.conv",
    "encoder.model.9.conv.conv": "encoder.layers.9.conv",
    "encoder.model.10.block.1.conv.conv": "encoder.layers.10.block.1.conv",
    "encoder.model.10.block.3.conv.conv": "encoder.layers.10.block.3.conv",
    "encoder.model.10.shortcut.conv.conv": "encoder.layers.10.shortcut.conv",
    "encoder.model.12.conv.conv": "encoder.layers.12.conv",
    "encoder.model.13.lstm": "encoder.layers.13.lstm",
    "encoder.model.15.conv.conv": "encoder.layers.15.conv",
}
# Extra normalization-layer names used only by the 48kHz (time_group_norm) model.
MAPPING_ENCODER_48K = {
    "encoder.model.0.conv.norm": "encoder.layers.0.norm",
    "encoder.model.1.block.1.conv.norm": "encoder.layers.1.block.1.norm",
    "encoder.model.1.block.3.conv.norm": "encoder.layers.1.block.3.norm",
    "encoder.model.1.shortcut.conv.norm": "encoder.layers.1.shortcut.norm",
    "encoder.model.3.conv.norm": "encoder.layers.3.norm",
    "encoder.model.4.block.1.conv.norm": "encoder.layers.4.block.1.norm",
    "encoder.model.4.block.3.conv.norm": "encoder.layers.4.block.3.norm",
    "encoder.model.4.shortcut.conv.norm": "encoder.layers.4.shortcut.norm",
    "encoder.model.6.conv.norm": "encoder.layers.6.norm",
    "encoder.model.7.block.1.conv.norm": "encoder.layers.7.block.1.norm",
    "encoder.model.7.block.3.conv.norm": "encoder.layers.7.block.3.norm",
    "encoder.model.7.shortcut.conv.norm": "encoder.layers.7.shortcut.norm",
    "encoder.model.9.conv.norm": "encoder.layers.9.norm",
    "encoder.model.10.block.1.conv.norm": "encoder.layers.10.block.1.norm",
    "encoder.model.10.block.3.conv.norm": "encoder.layers.10.block.3.norm",
    "encoder.model.10.shortcut.conv.norm": "encoder.layers.10.shortcut.norm",
    "encoder.model.12.conv.norm": "encoder.layers.12.norm",
    "encoder.model.15.conv.norm": "encoder.layers.15.norm",
}
MAPPING_DECODER = {
    "decoder.model.0.conv.conv": "decoder.layers.0.conv",
    "decoder.model.1.lstm": "decoder.layers.1.lstm",
    "decoder.model.3.convtr.convtr": "decoder.layers.3.conv",
    "decoder.model.4.block.1.conv.conv": "decoder.layers.4.block.1.conv",
    "decoder.model.4.block.3.conv.conv": "decoder.layers.4.block.3.conv",
    "decoder.model.4.shortcut.conv.conv": "decoder.layers.4.shortcut.conv",
    "decoder.model.6.convtr.convtr": "decoder.layers.6.conv",
    "decoder.model.7.block.1.conv.conv": "decoder.layers.7.block.1.conv",
    "decoder.model.7.block.3.conv.conv": "decoder.layers.7.block.3.conv",
    "decoder.model.7.shortcut.conv.conv": "decoder.layers.7.shortcut.conv",
    "decoder.model.9.convtr.convtr": "decoder.layers.9.conv",
    "decoder.model.10.block.1.conv.conv": "decoder.layers.10.block.1.conv",
    "decoder.model.10.block.3.conv.conv": "decoder.layers.10.block.3.conv",
    "decoder.model.10.shortcut.conv.conv": "decoder.layers.10.shortcut.conv",
    "decoder.model.12.convtr.convtr": "decoder.layers.12.conv",
    "decoder.model.13.block.1.conv.conv": "decoder.layers.13.block.1.conv",
    "decoder.model.13.block.3.conv.conv": "decoder.layers.13.block.3.conv",
    "decoder.model.13.shortcut.conv.conv": "decoder.layers.13.shortcut.conv",
    "decoder.model.15.conv.conv": "decoder.layers.15.conv",
}
MAPPING_DECODER_48K = {
    "decoder.model.0.conv.norm": "decoder.layers.0.norm",
    "decoder.model.3.convtr.norm": "decoder.layers.3.norm",
    "decoder.model.4.block.1.conv.norm": "decoder.layers.4.block.1.norm",
    "decoder.model.4.block.3.conv.norm": "decoder.layers.4.block.3.norm",
    "decoder.model.4.shortcut.conv.norm": "decoder.layers.4.shortcut.norm",
    "decoder.model.6.convtr.norm": "decoder.layers.6.norm",
    "decoder.model.7.block.1.conv.norm": "decoder.layers.7.block.1.norm",
    "decoder.model.7.block.3.conv.norm": "decoder.layers.7.block.3.norm",
    "decoder.model.7.shortcut.conv.norm": "decoder.layers.7.shortcut.norm",
    "decoder.model.9.convtr.norm": "decoder.layers.9.norm",
    "decoder.model.10.block.1.conv.norm": "decoder.layers.10.block.1.norm",
    "decoder.model.10.block.3.conv.norm": "decoder.layers.10.block.3.norm",
    "decoder.model.10.shortcut.conv.norm": "decoder.layers.10.shortcut.norm",
    "decoder.model.12.convtr.norm": "decoder.layers.12.norm",
    "decoder.model.13.block.1.conv.norm": "decoder.layers.13.block.1.norm",
    "decoder.model.13.block.3.conv.norm": "decoder.layers.13.block.3.norm",
    "decoder.model.13.shortcut.conv.norm": "decoder.layers.13.shortcut.norm",
    "decoder.model.15.conv.norm": "decoder.layers.15.norm",
}
# Combined tables per checkpoint family (24kHz/32kHz share one layout).
MAPPING_24K = {
    **MAPPING_QUANTIZER,
    **MAPPING_ENCODER,
    **MAPPING_DECODER,
}
MAPPING_48K = {
    **MAPPING_QUANTIZER,
    **MAPPING_ENCODER,
    **MAPPING_ENCODER_48K,
    **MAPPING_DECODER,
    **MAPPING_DECODER_48K,
}
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []
+
+
def set_recursively(hf_pointer, key, value, full_name, weight_type):
    """Assign `value` into the model at dotted attribute path `key`.

    Args:
        hf_pointer: root module (the transformers model) to descend into.
        key: dotted attribute path, e.g. "encoder.layers.0.conv".
        value: tensor from the original checkpoint.
        full_name: original checkpoint key, used only for messages.
        weight_type: name of the sub-attribute receiving the value (e.g. "weight",
            "bias", "weight_g", "weight_ih_l0", ...) or `None` to assign to the
            resolved attribute itself.

    Raises:
        ValueError: when the destination shape does not match `value.shape`.
    """
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    # Resolve the actual parameter/buffer that will receive the value.
    target = getattr(hf_pointer, weight_type) if weight_type is not None else hf_pointer

    if target.shape != value.shape:
        # BUGFIX: the previous message's ternary had the wrong precedence and
        # dropped `key` entirely when `weight_type` was None.
        raise ValueError(
            f"Shape of hf {key + ('.' + weight_type) if weight_type is not None else key} is {target.shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    # Generic assignment replaces the original 17-branch if/elif chain: every
    # branch performed exactly `getattr(hf_pointer, weight_type).data = value`.
    target.data = value

    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
+
+
def should_ignore(name, ignore_keys):
    """Return True when `name` matches one of the ignore patterns.

    Three pattern forms are supported:
      * "prefix.*"        -> matches any name starting with "prefix."
      * "head.*.tail"     -> matches when both fragments occur in the name
      * plain substring   -> matches when the pattern occurs anywhere in the name
    """

    def _matches(pattern):
        if pattern.endswith(".*"):
            return name.startswith(pattern[:-1])
        if ".*." in pattern:
            head, tail = pattern.split(".*.")
            return head in name and tail in name
        return pattern in name

    return any(_matches(pattern) for pattern in ignore_keys)
+
+
def recursively_load_weights(orig_dict, hf_model, model_name):
    """Copy every tensor from the original EnCodec state dict into `hf_model`.

    Keys are translated via the MAPPING_* tables; names matching `IGNORE_KEYS`
    are skipped and any key that matches no mapping is reported at the end.
    """
    unused_weights = []

    # Pick the renaming table for this checkpoint family.
    if model_name in ["encodec_24khz", "encodec_32khz"]:
        MAPPING = MAPPING_24K
    elif model_name == "encodec_48khz":
        MAPPING = MAPPING_48K
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    for name, value in orig_dict.items():
        if should_ignore(name, IGNORE_KEYS):
            logger.info(f"{name} was ignored")
            continue

        is_used = False
        for key, mapped_key in MAPPING.items():
            if "*" in key:
                # Wildcard pattern: match on the suffix only if both fragments appear.
                prefix, suffix = key.split(".*.")
                if prefix in name and suffix in name:
                    key = suffix

            if key in name:
                # HACK otherwise .embed gets initialized with .embed_avg too
                if key.endswith("embed") and name.endswith("embed_avg"):
                    continue

                is_used = True
                if "*" in mapped_key:
                    # Recover the numeric layer index preceding the matched suffix
                    # in the original name and substitute it for `*`.
                    layer_index = name.split(key)[0].split(".")[-2]
                    mapped_key = mapped_key.replace("*", layer_index)
                # Classify which sub-attribute of the destination module receives
                # the tensor; more specific names must be checked first (e.g.
                # "weight_g" before the generic "weight").
                if "weight_g" in name:
                    weight_type = "weight_g"
                elif "weight_v" in name:
                    weight_type = "weight_v"
                elif "weight_ih_l0" in name:
                    weight_type = "weight_ih_l0"
                elif "weight_hh_l0" in name:
                    weight_type = "weight_hh_l0"
                elif "bias_ih_l0" in name:
                    weight_type = "bias_ih_l0"
                elif "bias_hh_l0" in name:
                    weight_type = "bias_hh_l0"
                elif "weight_ih_l1" in name:
                    weight_type = "weight_ih_l1"
                elif "weight_hh_l1" in name:
                    weight_type = "weight_hh_l1"
                elif "bias_ih_l1" in name:
                    weight_type = "bias_ih_l1"
                elif "bias_hh_l1" in name:
                    weight_type = "bias_hh_l1"
                elif "bias" in name:
                    weight_type = "bias"
                elif "weight" in name:
                    weight_type = "weight"
                elif "running_mean" in name:
                    weight_type = "running_mean"
                elif "running_var" in name:
                    weight_type = "running_var"
                elif "num_batches_tracked" in name:
                    weight_type = "num_batches_tracked"
                else:
                    weight_type = None
                set_recursively(hf_model, mapped_key, value, name, weight_type)
            continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")
+
+
@torch.no_grad()
def convert_checkpoint(
    model_name,
    checkpoint_path,
    pytorch_dump_folder_path,
    config_path=None,
    repo_id=None,
):
    """
    Copy/paste/tweak model's weights to transformers design.

    Args:
        model_name: one of "encodec_24khz", "encodec_32khz", "encodec_48khz".
        checkpoint_path: path to the original fairseq-style checkpoint file.
        pytorch_dump_folder_path: output folder for the converted model and
            feature extractor.
        config_path: optional path to a pre-existing HF config.json; when absent
            the config is derived from `model_name`.
        repo_id: optional Hub repo to push the converted artifacts to.
    """
    if config_path is not None:
        config = EncodecConfig.from_pretrained(config_path)
    else:
        config = EncodecConfig()

    # Per-variant config overrides relative to the 24kHz defaults.
    if model_name == "encodec_24khz":
        pass  # config is already correct
    elif model_name == "encodec_32khz":
        config.upsampling_ratios = [8, 5, 4, 4]
        config.target_bandwidths = [2.2]
        config.num_filters = 64
        config.sampling_rate = 32_000
        config.codebook_size = 2048
        config.use_causal_conv = False
        config.normalize = False
        config.use_conv_shortcut = False
    elif model_name == "encodec_48khz":
        config.upsampling_ratios = [8, 5, 4, 2]
        config.target_bandwidths = [3.0, 6.0, 12.0, 24.0]
        config.sampling_rate = 48_000
        config.audio_channels = 2
        config.use_causal_conv = False
        config.norm_type = "time_group_norm"
        config.normalize = True
        config.chunk_length_s = 1.0
        config.overlap = 0.01
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    model = EncodecModel(config)

    feature_extractor = EncodecFeatureExtractor(
        feature_size=config.audio_channels,
        sampling_rate=config.sampling_rate,
        chunk_length_s=config.chunk_length_s,
        overlap=config.overlap,
    )
    feature_extractor.save_pretrained(pytorch_dump_folder_path)

    # `weights_only=True` restricts unpickling to tensor data.
    original_checkpoint = torch.load(checkpoint_path, weights_only=True)
    if "best_state" in original_checkpoint:
        # we might have a training state saved, in which case discard the yaml results and just retain the weights
        original_checkpoint = original_checkpoint["best_state"]
    recursively_load_weights(original_checkpoint, model, model_name)
    model.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        feature_extractor.push_to_hub(repo_id)
        model.push_to_hub(repo_id)
+
+
if __name__ == "__main__":
    # Command-line entry point for the conversion script.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="encodec_24khz",
        type=str,
        help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.",
    )
    # NOTE(review): `default=None` is redundant next to `required=True`; argparse ignores it.
    parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()
    # Positional order must match convert_checkpoint(model_name, checkpoint_path,
    # pytorch_dump_folder_path, config_path, repo_id).
    convert_checkpoint(
        args.model,
        args.checkpoint_path,
        args.pytorch_dump_folder_path,
        args.config_path,
        args.push_to_hub,
    )
diff --git a/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..f33191862e4891375059b882f8c50f5953dabc1b
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Feature extractor class for EnCodec."""
+
+from typing import List, Optional, Union
+
+import numpy as np
+
+from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
+from ...feature_extraction_utils import BatchFeature
+from ...utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class EncodecFeatureExtractor(SequenceFeatureExtractor):
+ r"""
+ Constructs an EnCodec feature extractor.
+
+ This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
+ most of the main methods. Users should refer to this superclass for more information regarding those methods.
+
+ Instantiating a feature extractor with the defaults will yield a similar configuration to that of the
+ [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.
+
+ Args:
+ feature_size (`int`, *optional*, defaults to 1):
+ The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
+ sampling_rate (`int`, *optional*, defaults to 24000):
+ The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
+ padding_value (`float`, *optional*, defaults to 0.0):
+ The value that is used to fill the padding values.
+ chunk_length_s (`float`, *optional*):
+ If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
+ overlap (`float`, *optional*):
+ Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
+ formula: `int((1.0 - self.overlap) * self.chunk_length)`.
+ """
+
+ model_input_names = ["input_values", "padding_mask"]
+
+ def __init__(
+ self,
+ feature_size: int = 1,
+ sampling_rate: int = 24000,
+ padding_value: float = 0.0,
+ chunk_length_s: Optional[float] = None,
+ overlap: Optional[float] = None,
+ **kwargs,
+ ):
+ super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
+ self.chunk_length_s = chunk_length_s
+ self.overlap = overlap
+
+ # This is a property because you might want to change the chunk_length_s on the fly
+ @property
+ def chunk_length(self) -> Optional[int]:
+ if self.chunk_length_s is None:
+ return None
+ else:
+ return int(self.chunk_length_s * self.sampling_rate)
+
+ # This is a property because you might want to change the chunk_length_s on the fly
+ @property
+ def chunk_stride(self) -> Optional[int]:
+ if self.chunk_length_s is None or self.overlap is None:
+ return None
+ else:
+ return max(1, int((1.0 - self.overlap) * self.chunk_length))
+
+ def __call__(
+ self,
+ raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+ padding: Optional[Union[bool, str, PaddingStrategy]] = None,
+ truncation: Optional[bool] = False,
+ max_length: Optional[int] = None,
+ return_tensors: Optional[Union[str, TensorType]] = None,
+ sampling_rate: Optional[int] = None,
+ ) -> BatchFeature:
+ """
+ Main method to featurize and prepare for the model one or several sequence(s).
+
+ Args:
+ raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+ The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
+ values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
+ `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
+ (`feature_size = 2`).
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
+ index) among:
+
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+ sequence is provided).
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+ acceptable input length for the model if that argument is not provided.
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+ lengths).
+ truncation (`bool`, *optional*, defaults to `False`):
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+ max_length (`int`, *optional*):
+ Maximum length of the returned list and optionally padding length (see above).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors instead of list of python integers. Acceptable values are:
+
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return Numpy `np.ndarray` objects.
+ sampling_rate (`int`, *optional*):
+ The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
+ `sampling_rate` at the forward call to prevent silent errors.
+ """
+ if sampling_rate is not None:
+ if sampling_rate != self.sampling_rate:
+ raise ValueError(
+ f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+ f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
+ f" {self.sampling_rate} and not {sampling_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ if padding and truncation:
+ raise ValueError("Both padding and truncation were set. Make sure you only set one.")
+ elif padding is None:
+ # by default let's pad the inputs
+ padding = True
+
+ is_batched = bool(
+ isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
+ )
+
+ if is_batched:
+ raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
+ elif not is_batched and not isinstance(raw_audio, np.ndarray):
+ raw_audio = np.asarray(raw_audio, dtype=np.float32)
+ elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
+ raw_audio = raw_audio.astype(np.float32)
+
+ # always return batch
+ if not is_batched:
+ raw_audio = [np.asarray(raw_audio).T]
+
+ # verify inputs are valid
+ for idx, example in enumerate(raw_audio):
+ if example.ndim > 2:
+ raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
+ if self.feature_size == 1 and example.ndim != 1:
+ raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
+ if self.feature_size == 2 and example.shape[-1] != 2:
+ raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
+
+ padded_inputs = None
+ input_values = BatchFeature({"input_values": raw_audio})
+ if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
+ if truncation:
+ max_length = min(array.shape[0] for array in raw_audio)
+ nb_step = int(np.floor(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ elif padding:
+ max_length = max(array.shape[0] for array in raw_audio)
+ nb_step = int(np.ceil(max_length / self.chunk_stride))
+ max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
+ padding = "max_length"
+ else:
+ padded_inputs = input_values
+
+ # normal padding on batch
+ if padded_inputs is None:
+ padded_inputs = self.pad(
+ input_values,
+ max_length=max_length,
+ truncation=truncation,
+ padding=padding,
+ return_attention_mask=padding,
+ )
+ if padding:
+ padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
+
+ input_values = []
+ for example in padded_inputs.pop("input_values"):
+ if self.feature_size == 1:
+ example = example[..., None]
+ input_values.append(example.T)
+
+ padded_inputs["input_values"] = input_values
+ if return_tensors is not None:
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
+
+ return padded_inputs
+
+
+__all__ = ["EncodecFeatureExtractor"]
diff --git a/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..670ac99e03e03a88c07e0478e5190622403eb77a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py
@@ -0,0 +1,818 @@
+# coding=utf-8
+# Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch EnCodec model."""
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_encodec import EncodecConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+# General docstring
+_CONFIG_FOR_DOC = "EncodecConfig"
+
+
+@dataclass
+class EncodecOutput(ModelOutput):
+ """
+ Args:
+ audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Decoded audio values, obtained using the decoder part of Encodec.
+ """
+
+ audio_codes: Optional[torch.LongTensor] = None
+ audio_values: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class EncodecEncoderOutput(ModelOutput):
+ """
+ Args:
+ audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
+ Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
+ """
+
+ audio_codes: Optional[torch.LongTensor] = None
+ audio_scales: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class EncodecDecoderOutput(ModelOutput):
+ """
+ Args:
+ audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
+ Decoded audio values, obtained using the decoder part of Encodec.
+ """
+
+ audio_values: Optional[torch.FloatTensor] = None
+
+
+class EncodecConv1d(nn.Module):
+ """Conv1d with asymmetric or causal padding and normalization."""
+
+ def __init__(
+ self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1
+ ):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.pad_mode = config.pad_mode
+ self.norm_type = config.norm_type
+
+ if self.norm_type not in ["weight_norm", "time_group_norm"]:
+ raise ValueError(
+ f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
+ )
+
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ logger.warning(
+ "EncodecConv1d has been initialized with stride > 1 and dilation > 1"
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+ )
+
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if self.norm_type == "weight_norm":
+ self.conv = weight_norm(self.conv)
+ elif self.norm_type == "time_group_norm":
+ self.norm = nn.GroupNorm(1, out_channels)
+
+ kernel_size = self.conv.kernel_size[0]
+ stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
+ dilation = self.conv.dilation[0]
+
+ # Effective kernel size with dilations.
+ kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)
+
+ self.register_buffer("stride", stride, persistent=False)
+ self.register_buffer("kernel_size", kernel_size, persistent=False)
+ self.register_buffer("padding_total", kernel_size - stride, persistent=False)
+
+ def _get_extra_padding_for_conv1d(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ """See `pad_for_conv1d`."""
+ length = hidden_states.shape[-1]
+ n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
+ n_frames = torch.ceil(n_frames).to(torch.int64) - 1
+ ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+
+ return ideal_length - length
+
+ @staticmethod
+ def _pad1d(hidden_states: torch.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0):
+ """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
+ """
+ length = hidden_states.shape[-1]
+ padding_left, padding_right = paddings
+ if not mode == "reflect":
+ return nn.functional.pad(hidden_states, paddings, mode, value)
+
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
+ padded = nn.functional.pad(hidden_states, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+
+ def forward(self, hidden_states):
+ extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
+
+ if self.causal:
+ # Left padding for causal
+ hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = self.padding_total // 2
+ padding_left = self.padding_total - padding_right
+ hidden_states = self._pad1d(
+ hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+ )
+
+ hidden_states = self.conv(hidden_states)
+
+ if self.norm_type == "time_group_norm":
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states
+
+
+class EncodecConvTranspose1d(nn.Module):
+ """ConvTranspose1d with asymmetric or causal padding and normalization."""
+
+ def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
+ super().__init__()
+ self.causal = config.use_causal_conv
+ self.trim_right_ratio = config.trim_right_ratio
+ self.norm_type = config.norm_type
+ if self.norm_type not in ["weight_norm", "time_group_norm"]:
+ raise ValueError(
+ f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}'
+ )
+
+ self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
+
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if config.norm_type == "weight_norm":
+ self.conv = weight_norm(self.conv)
+ elif config.norm_type == "time_group_norm":
+ self.norm = nn.GroupNorm(1, out_channels)
+
+ if not (self.causal or self.trim_right_ratio == 1.0):
+ raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")
+
+ def forward(self, hidden_states):
+ kernel_size = self.conv.kernel_size[0]
+ stride = self.conv.stride[0]
+ padding_total = kernel_size - stride
+
+ hidden_states = self.conv(hidden_states)
+
+ if self.norm_type == "time_group_norm":
+ hidden_states = self.norm(hidden_states)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+
+ padding_left = padding_total - padding_right
+
+ # unpad
+ end = hidden_states.shape[-1] - padding_right
+ hidden_states = hidden_states[..., padding_left:end]
+ return hidden_states
+
+
+class EncodecLSTM(nn.Module):
+ """
+ LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout.
+ """
+
+ def __init__(self, config, dimension):
+ super().__init__()
+ self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)
+
+ def forward(self, hidden_states):
+ hidden_states = hidden_states.permute(2, 0, 1)
+ hidden_states = self.lstm(hidden_states)[0] + hidden_states
+ hidden_states = hidden_states.permute(1, 2, 0)
+ return hidden_states
+
+
+class EncodecResnetBlock(nn.Module):
+ """
+ Residual block from SEANet model as used by EnCodec.
+ """
+
+ def __init__(self, config: EncodecConfig, dim: int, dilations: List[int]):
+ super().__init__()
+ kernel_sizes = (config.residual_kernel_size, 1)
+ if len(kernel_sizes) != len(dilations):
+ raise ValueError("Number of kernel sizes should match number of dilations")
+
+ hidden = dim // config.compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [nn.ELU()]
+ block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
+ self.block = nn.ModuleList(block)
+
+ if config.use_conv_shortcut:
+ self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
+ else:
+ self.shortcut = nn.Identity()
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ for layer in self.block:
+ hidden_states = layer(hidden_states)
+
+ return self.shortcut(residual) + hidden_states
+
+
+class EncodecEncoder(nn.Module):
+ """SEANet encoder as used by EnCodec."""
+
+ def __init__(self, config: EncodecConfig):
+ super().__init__()
+ model = [EncodecConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
+ scaling = 1
+
+ # Downsample to raw audio scale
+ for ratio in reversed(config.upsampling_ratios):
+ current_scale = scaling * config.num_filters
+ # Add residual layers
+ for j in range(config.num_residual_layers):
+ model += [EncodecResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
+ # Add downsampling layers
+ model += [nn.ELU()]
+ model += [EncodecConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
+ scaling *= 2
+
+ model += [EncodecLSTM(config, scaling * config.num_filters)]
+ model += [nn.ELU()]
+ model += [EncodecConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
+
+ self.layers = nn.ModuleList(model)
+
+ def forward(self, hidden_states):
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class EncodecDecoder(nn.Module):
+ """SEANet decoder as used by EnCodec."""
+
+ def __init__(self, config: EncodecConfig):
+ super().__init__()
+ scaling = int(2 ** len(config.upsampling_ratios))
+ model = [EncodecConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
+
+ model += [EncodecLSTM(config, scaling * config.num_filters)]
+
+ # Upsample to raw audio scale
+ for ratio in config.upsampling_ratios:
+ current_scale = scaling * config.num_filters
+ # Add upsampling layers
+ model += [nn.ELU()]
+ model += [
+ EncodecConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
+ ]
+ # Add residual layers
+ for j in range(config.num_residual_layers):
+ model += [EncodecResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
+ scaling //= 2
+
+ # Add final layers
+ model += [nn.ELU()]
+ model += [EncodecConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
+ self.layers = nn.ModuleList(model)
+
+ def forward(self, hidden_states):
+ for layer in self.layers:
+ hidden_states = layer(hidden_states)
+ return hidden_states
+
+
+class EncodecEuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance."""
+
+ def __init__(self, config: EncodecConfig):
+ super().__init__()
+ embed = torch.zeros(config.codebook_size, config.codebook_dim)
+
+ self.codebook_size = config.codebook_size
+
+ self.register_buffer("inited", torch.Tensor([True]))
+ self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ def quantize(self, hidden_states):
+ embed = self.embed.t()
+ scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
+ dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def encode(self, hidden_states):
+ shape = hidden_states.shape
+ # pre-process
+ hidden_states = hidden_states.reshape((-1, shape[-1]))
+ # quantize
+ embed_ind = self.quantize(hidden_states)
+ # post-process
+ embed_ind = embed_ind.view(*shape[:-1])
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = nn.functional.embedding(embed_ind, self.embed)
+ return quantize
+
+
+class EncodecVectorQuantization(nn.Module):
+ """
+ Vector quantization implementation. Currently supports only euclidean distance.
+ """
+
+ def __init__(self, config: EncodecConfig):
+ super().__init__()
+ self.codebook = EncodecEuclideanCodebook(config)
+
+ def encode(self, hidden_states):
+ hidden_states = hidden_states.permute(0, 2, 1)
+ embed_in = self.codebook.encode(hidden_states)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self.codebook.decode(embed_ind)
+ quantize = quantize.permute(0, 2, 1)
+ return quantize
+
+
+class EncodecResidualVectorQuantizer(nn.Module):
+ """Residual Vector Quantizer."""
+
+ def __init__(self, config: EncodecConfig):
+ super().__init__()
+ self.codebook_size = config.codebook_size
+ self.frame_rate = config.frame_rate
+ self.num_quantizers = config.num_quantizers
+ self.layers = nn.ModuleList([EncodecVectorQuantization(config) for _ in range(config.num_quantizers)])
+
+ def get_num_quantizers_for_bandwidth(self, bandwidth: Optional[float] = None) -> int:
+ """Return num_quantizers based on specified target bandwidth."""
+ bw_per_q = math.log2(self.codebook_size) * self.frame_rate
+ num_quantizers = self.num_quantizers
+ if bandwidth is not None and bandwidth > 0.0:
+ num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
+ return num_quantizers
+
+ def encode(self, embeddings: torch.Tensor, bandwidth: Optional[float] = None) -> torch.Tensor:
+ """
+ Encode a given input tensor with the specified frame rate at the given bandwidth. The RVQ encode method sets
+ the appropriate number of quantizers to use and returns indices for each quantizer.
+ """
+ num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
+ residual = embeddings
+ all_indices = []
+ for layer in self.layers[:num_quantizers]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation."""
+ quantized_out = torch.tensor(0.0, device=codes.device)
+ for i, indices in enumerate(codes):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
+
+
+class EncodecPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = EncodecConfig
+ base_model_prefix = "encodec"
+ main_input_name = "input_values"
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LSTM):
+ for name, param in module.named_parameters():
+ if "weight" in name:
+ nn.init.xavier_uniform_(param)
+ elif "bias" in name:
+ nn.init.constant_(param, 0.0)
+
+
+ENCODEC_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`EncodecConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+ENCODEC_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
+ Raw audio input converted to Float and padded to the appropriate length in order to be encoded using chunks
+ of length self.chunk_length and a stride of `config.chunk_stride`.
+ padding_mask (`torch.BoolTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
+ Mask to avoid computing scaling factors on padding token indices (can we avoid computing conv on these?).
+ Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+
+
+ `padding_mask` should always be passed, unless the input was truncated or not padded. This is because in
+ order to process tensors effectively, the input audio should be padded so that `input_length % stride =
+ step` with `step = chunk_length-stride`. This ensures that all chunks are of the same shape
+
+
+
+ bandwidth (`float`, *optional*):
+ The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
+ bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
+ `bandwidth == 6.0`
+ audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
+ Scaling factor for each `audio_codes` input.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The EnCodec neural audio codec model.",
+ ENCODEC_START_DOCSTRING,
+)
+class EncodecModel(EncodecPreTrainedModel):
+ def __init__(self, config: EncodecConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = EncodecEncoder(config)
+ self.decoder = EncodecDecoder(config)
+
+ self.quantizer = EncodecResidualVectorQuantizer(config)
+
+ self.bits_per_codebook = int(math.log2(self.config.codebook_size))
+ if 2**self.bits_per_codebook != self.config.codebook_size:
+ raise ValueError("The codebook_size must be a power of 2.")
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ def _encode_frame(
+ self, input_values: torch.Tensor, bandwidth: float, padding_mask: int
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """
+ Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
+ normalized. The padding mask is required to compute the correct scale.
+ """
+ length = input_values.shape[-1]
+ duration = length / self.config.sampling_rate
+
+ if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
+ raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")
+
+ scale = None
+ if self.config.normalize:
+ # if the padding is non zero
+ input_values = input_values * padding_mask.unsqueeze(1)
+ mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
+ scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
+ input_values = input_values / scale
+
+ embeddings = self.encoder(input_values)
+ codes = self.quantizer.encode(embeddings, bandwidth)
+ codes = codes.transpose(0, 1)
+ return codes, scale
+
+ def encode(
+ self,
+ input_values: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ bandwidth: Optional[float] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], EncodecEncoderOutput]:
+ """
+ Encodes the input audio waveform into discrete codes.
+
+ Args:
+ input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Float values of the input audio waveform.
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Padding mask used to pad the `input_values`.
+ bandwidth (`float`, *optional*):
+ The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
+ bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented
+ as bandwidth == 6.0
+
+ Returns:
+ A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
+ factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
+ `codebook` of shape `[batch_size, num_codebooks, frames]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # default to the smallest supported bandwidth when the caller does not pick one
+ if bandwidth is None:
+ bandwidth = self.config.target_bandwidths[0]
+ if bandwidth not in self.config.target_bandwidths:
+ raise ValueError(
+ f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}."
+ )
+
+ _, channels, input_length = input_values.shape
+
+ if channels < 1 or channels > 2:
+ raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
+
+ # chunk_length is None means the model was configured without chunking:
+ # the whole input is encoded as a single frame
+ chunk_length = self.config.chunk_length
+ if chunk_length is None:
+ chunk_length = input_length
+ stride = input_length
+ else:
+ stride = self.config.chunk_stride
+
+ if padding_mask is None:
+ padding_mask = torch.ones_like(input_values).bool()
+
+ encoded_frames = []
+ scales = []
+
+ # step is the overlap between consecutive chunks; the input must be padded
+ # so that the chunks tile input_length exactly
+ step = chunk_length - stride
+ if (input_length % stride) - step != 0:
+ raise ValueError(
+ "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly."
+ )
+
+ # encode each (possibly overlapping) chunk independently
+ for offset in range(0, input_length - step, stride):
+ mask = padding_mask[..., offset : offset + chunk_length].bool()
+ frame = input_values[:, :, offset : offset + chunk_length]
+ encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
+ encoded_frames.append(encoded_frame)
+ scales.append(scale)
+
+ encoded_frames = torch.stack(encoded_frames)
+
+ if not return_dict:
+ return (encoded_frames, scales)
+
+ return EncodecEncoderOutput(encoded_frames, scales)
+
+ @staticmethod
+ def _linear_overlap_add(frames: List[torch.Tensor], stride: int):
+ """Overlap-add `frames` placed `stride` apart, cross-fading overlaps with triangular weights."""
+ # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario
+ # e.g., more than 2 frames per position.
+ # The core idea is to use a weight function that is a triangle,
+ # with a maximum value at the middle of the chunk.
+ # We use this weighting when summing the frames, and divide by the sum of weights
+ # for each positions at the end. Thus:
+ # - if a frame is the only one to cover a position, the weighting is a no-op.
+ # - if 2 frames cover a position:
+ # ... ...
+ # / \/ \
+ # / /\ \
+ # S T , i.e. S offset of second frame starts, T end of first frame.
+ # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset.
+ # After the final normalization, the weight of the second frame at position `t` is
+ # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want.
+ #
+ # - if more than 2 frames overlap at a given point, we hope that by induction
+ # something sensible happens.
+ if len(frames) == 0:
+ raise ValueError("`frames` cannot be an empty list.")
+
+ device = frames[0].device
+ dtype = frames[0].dtype
+ shape = frames[0].shape[:-1]
+ # frames are `stride` apart; only the last one may have a different length
+ total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]
+
+ frame_length = frames[0].shape[-1]
+ # triangular window peaking mid-frame; the linspace endpoints are dropped so no weight is exactly zero
+ time_vec = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1]
+ weight = 0.5 - (time_vec - 0.5).abs()
+
+ sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
+ out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
+ offset: int = 0
+
+ for frame in frames:
+ frame_length = frame.shape[-1]
+ # slice the window in case this frame is shorter than the first one
+ out[..., offset : offset + frame_length] += weight[:frame_length] * frame
+ sum_weight[offset : offset + frame_length] += weight[:frame_length]
+ offset += stride
+
+ if sum_weight.min() == 0:
+ raise ValueError(f"`sum_weight` minimum element must be bigger than zero: {sum_weight}`")
+
+ # normalize so overlapping windows cross-fade to unit gain
+ return out / sum_weight
+
def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Decode one chunk of codebook indices back into a waveform chunk."""
    # Move the codebook axis first before dequantizing — the inverse of the
    # transpose applied to the codes right after quantization in `_encode_frame`.
    embeddings = self.quantizer.decode(codes.transpose(0, 1))
    audio = self.decoder(embeddings)
    # Undo the per-chunk rescaling applied during encoding, when present.
    return audio if scale is None else audio * scale.view(-1, 1, 1)
+
+ def decode(
+ self,
+ audio_codes: torch.Tensor,
+ audio_scales: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]:
+ """
+ Decodes the given frames into an output audio waveform.
+
+ Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
+ trimmed.
+
+ Args:
+ audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
+ Discrete code embeddings computed using `model.encode`.
+ audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
+ Scaling factor for each `audio_codes` input.
+ padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+ Padding mask used to pad the `input_values`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ """
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # chunk_length is None means the audio was encoded as one single frame
+ chunk_length = self.config.chunk_length
+ if chunk_length is None:
+ if len(audio_codes) != 1:
+ raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+ audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
+ else:
+ decoded_frames = []
+
+ for frame, scale in zip(audio_codes, audio_scales):
+ frames = self._decode_frame(frame, scale)
+ decoded_frames.append(frames)
+
+ # cross-fade the overlapping decoded chunks; chunk_stride may be None, hence the `or 1`
+ audio_values = self._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)
+
+ # truncate based on padding mask
+ if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
+ audio_values = audio_values[..., : padding_mask.shape[-1]]
+
+ if not return_dict:
+ return (audio_values,)
+ return EncodecDecoderOutput(audio_values)
+
+ @add_start_docstrings_to_model_forward(ENCODEC_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=EncodecOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_values: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ bandwidth: Optional[float] = None,
+ audio_codes: Optional[torch.Tensor] = None,
+ audio_scales: Optional[torch.Tensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from datasets import load_dataset
+ >>> from transformers import AutoProcessor, EncodecModel
+
+ >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
+ >>> audio_sample = dataset["train"]["audio"][0]["array"]
+
+ >>> model_id = "facebook/encodec_24khz"
+ >>> model = EncodecModel.from_pretrained(model_id)
+ >>> processor = AutoProcessor.from_pretrained(model_id)
+
+ >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> audio_codes = outputs.audio_codes
+ >>> audio_values = outputs.audio_values
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if padding_mask is None:
+ padding_mask = torch.ones_like(input_values).bool()
+
+ # codes and scales are only meaningful as a pair: decoding requires both
+ if audio_codes is not None and audio_scales is None:
+ raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")
+
+ if audio_scales is not None and audio_codes is None:
+ raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")
+
+ # nothing precomputed: run the encoder first (return_dict=False -> plain tuple)
+ if audio_scales is None and audio_codes is None:
+ audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, False)
+
+ # [0] picks the audio values in both branches — NOTE(review): relies on the
+ # ModelOutput supporting integer indexing when return_dict=True; confirm.
+ audio_values = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0]
+ if not return_dict:
+ return (audio_codes, audio_values)
+
+ return EncodecOutput(audio_codes=audio_codes, audio_values=audio_values)
+
+
+# Explicit public API of this module.
+__all__ = ["EncodecModel", "EncodecPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c786feb9213fdd31640c0fdeaead5164026ad37a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+# Type checkers see eager star-imports; at runtime the module object is swapped
+# for a _LazyModule that resolves submodules on first attribute access.
+if TYPE_CHECKING:
+ from .configuration_encoder_decoder import *
+ from .modeling_encoder_decoder import *
+ from .modeling_flax_encoder_decoder import *
+ from .modeling_tf_encoder_decoder import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5eff83e55824786d70c821ecfa210e07c27da2e
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py
@@ -0,0 +1,111 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ..auto import AutoConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class EncoderDecoderConfig(PretrainedConfig):
+ r"""
+ [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is
+ used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder
+ configs.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ kwargs (*optional*):
+ Dictionary of keyword arguments. Notably:
+
+ - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the encoder config.
+ - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines
+ the decoder config.
+
+ Examples:
+
+ ```python
+ >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
+
+ >>> # Initializing a BERT google-bert/bert-base-uncased style configuration
+ >>> config_encoder = BertConfig()
+ >>> config_decoder = BertConfig()
+
+ >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
+
+ >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations
+ >>> model = EncoderDecoderModel(config=config)
+
+ >>> # Accessing the model configuration
+ >>> config_encoder = model.config.encoder
+ >>> config_decoder = model.config.decoder
+ >>> # set decoder config to causal lm
+ >>> config_decoder.is_decoder = True
+ >>> config_decoder.add_cross_attention = True
+
+ >>> # Saving the model, including its configuration
+ >>> model.save_pretrained("my-model")
+
+ >>> # loading model and config from pretrained folder
+ >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model")
+ >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config)
+ ```"""
+
+ model_type = "encoder-decoder"
+ # Both sub-configs are resolved dynamically from their serialized `model_type`.
+ sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig}
+ has_no_defaults_at_init = True
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ if "encoder" not in kwargs or "decoder" not in kwargs:
+ # NOTE(review): "configuraton" typo in the message below kept as-is to match the upstream runtime string.
+ raise ValueError(
+ f"A configuraton of type {self.model_type} cannot be instantiated because "
+ f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}"
+ )
+ # Pop `model_type` so AutoConfig.for_model can rebuild the concrete config
+ # class from the remaining plain-dict fields.
+ encoder_config = kwargs.pop("encoder")
+ encoder_model_type = encoder_config.pop("model_type")
+ decoder_config = kwargs.pop("decoder")
+ decoder_model_type = decoder_config.pop("model_type")
+
+ self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
+ self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
+ self.is_encoder_decoder = True
+
+ @classmethod
+ def from_encoder_decoder_configs(
+ cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
+ ) -> PretrainedConfig:
+ r"""
+ Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
+ decoder model configuration.
+
+ Returns:
+ [`EncoderDecoderConfig`]: An instance of a configuration object
+ """
+ # The decoder must run causally and cross-attend over the encoder output.
+ logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config")
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs)
+
+
+__all__ = ["EncoderDecoderConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..415fd058e45d8fe6e44ca30ddfe178ffc717cf19
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -0,0 +1,689 @@
+# coding=utf-8
+# Copyright 2018 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Encoder-Decoder architectures"""
+
+import gc
+import inspect
+import os
+import tempfile
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...configuration_utils import PretrainedConfig
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+# Name used by the docstring decorators below to link back to the config class.
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+# Warning text about the v4.12 change: the loss is now computed inside the
+# encoder-decoder framework from `labels`, not inside the decoder.
+DEPRECATION_WARNING = (
+ "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
+ " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
+ " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the"
+ " labels, no need to pass them yourself anymore."
+)
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
+ [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+ generative task, like summarization.
+
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
+ Zhou, Wei Li, Peter J. Liu.
+
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
+ (see the examples for more information).
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the
+ right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`.
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ encoder_outputs (`tuple(torch.FloatTensor)`, *optional*):
+ This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor
+ of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the
+ decoder.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+ representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
+ into associated vectors than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
+ ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
+ kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
+
+ - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
+ - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
+"""
+
+
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The result starts with `decoder_start_token_id`, followed by `input_ids` without its
    last column; any -100 label placeholders are replaced by `pad_token_id`.
    """
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")
    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")

    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # replace possible -100 values in labels by `pad_token_id`
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class EncoderDecoderModel(PreTrainedModel, GenerationMixin):
+ r"""
+ [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+ of the base model classes of the library as encoder and another one as decoder when created with the
+ :meth*~transformers.AutoModel.from_pretrained* class method for the encoder and
+ :meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder.
+ """
+
+ config_class = EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ main_input_name = "input_ids"
+ supports_gradient_checkpointing = True
+ _supports_param_buffer_assignment = False
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+
+ def __init__(
+ self,
+ config: Optional[PretrainedConfig] = None,
+ encoder: Optional[PreTrainedModel] = None,
+ decoder: Optional[PreTrainedModel] = None,
+ ):
+ """Create the composite model from a config, explicit sub-models, or both."""
+ if config is None and (encoder is None or decoder is None):
+ raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
+ if config is None:
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
+ else:
+ if not isinstance(config, self.config_class):
+ raise ValueError(f"Config: {config} has to be of type {self.config_class}")
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
+ )
+
+ # initialize with config
+ super().__init__(config)
+
+ if encoder is None:
+ # NOTE(review): AutoModel is already imported at module top; this local import is redundant but harmless.
+ from ..auto.modeling_auto import AutoModel
+
+ encoder = AutoModel.from_config(config.encoder)
+
+ if decoder is None:
+ # NOTE(review): same redundancy as above for AutoModelForCausalLM.
+ from ..auto.modeling_auto import AutoModelForCausalLM
+
+ decoder = AutoModelForCausalLM.from_config(config.decoder)
+
+ self.encoder = encoder
+ self.decoder = decoder
+
+ # warn when explicitly-passed sub-models carry configs that differ from the shared one
+ if self.encoder.config.to_dict() != self.config.encoder.to_dict():
+ logger.warning(
+ f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
+ f" {self.config.encoder}"
+ )
+ if self.decoder.config.to_dict() != self.config.decoder.to_dict():
+ logger.warning(
+ f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
+ f" {self.config.decoder}"
+ )
+
+ # make sure that the individual model's config refers to the shared config
+ # so that the updates to the config will be synced
+ # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel
+ self.config.encoder._attn_implementation = self.encoder.config._attn_implementation
+ self.config.decoder._attn_implementation = self.decoder.config._attn_implementation
+ self.encoder.config = self.config.encoder
+ self.decoder.config = self.config.decoder
+
+ # encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
+
+ if self.encoder.get_output_embeddings() is not None:
+ raise ValueError(
+ f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
+ )
+
+ # the decoder's forward must accept `encoder_hidden_states` to receive the encoder output
+ decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
+ if "encoder_hidden_states" not in decoder_signature:
+ raise ValueError(
+ "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
+ "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
+ )
+
+ # tie encoder, decoder weights if config set accordingly
+ self.tie_weights()
+
+ def tie_weights(self):
+ """Tie input/output embeddings inside each sub-model, then optionally tie encoder to decoder."""
+ self.encoder.tie_weights()
+ self.decoder.tie_weights()
+ # tie encoder & decoder if needed
+ if self.config.tie_encoder_decoder:
+ # tie encoder and decoder base model
+ decoder_base_model_prefix = self.decoder.base_model_prefix
+ tied_weights = self._tie_encoder_decoder_weights(
+ self.encoder,
+ self.decoder._modules[decoder_base_model_prefix],
+ self.decoder.base_model_prefix,
+ "encoder",
+ )
+ # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+ # attribute, not an instance member, therefore modifying it will modify the entire class,
+ # leading to issues on subsequent calls by different tests or subsequent calls.
+ self._dynamic_tied_weights_keys = tied_weights
+
def _init_weights(self, module):
    """Delegate weight initialization to whichever sub-model owns `module` (no-op otherwise)."""
    for sub_model in (self.encoder, self.decoder):
        if module in sub_model.modules():
            sub_model._init_weights(module)
            return
+
+ def get_encoder(self):
+ """Return the encoder sub-model."""
+ return self.encoder
+
+ def get_decoder(self):
+ """Return the decoder sub-model."""
+ return self.decoder
+
+ def get_input_embeddings(self):
+ """Delegate to the encoder's input embeddings."""
+ return self.encoder.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ """Delegate to the decoder's output embeddings."""
+ return self.decoder.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ """Replace the decoder's output embeddings."""
+ return self.decoder.set_output_embeddings(new_embeddings)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+ r"""
+ Example:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel
+
+ >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+ ```"""
+
+ # TF checkpoints are converted by round-tripping the weights through a temp directory.
+ from_tf = kwargs.pop("from_tf", False)
+ if from_tf:
+ from transformers import TFEncoderDecoderModel
+
+ # a workaround to load from tensorflow checkpoint
+ # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get
+ # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is
+ # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The
+ # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`,
+ # which should not occur when we want to save the components alone.
+ # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see
+ # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245
+ # (the change in `src/transformers/modeling_tf_utils.py`)
+ _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+ config = _tf_model.config
+
+ # Using `tf_model` instead
+ encoder = _tf_model.encoder.__class__(_tf_model.config.encoder)
+ decoder = _tf_model.decoder.__class__(_tf_model.config.decoder)
+ # Make sure models are built
+ encoder(encoder.dummy_inputs)
+ decoder(decoder.dummy_inputs)
+
+ # Get the variable correspondence between `_tf_model` and `encoder` and `decoder`
+ encoder_variables = {}
+ for v in encoder.trainable_variables + encoder.non_trainable_variables:
+ encoder_variables["/".join(v.name.split("/")[1:])] = v
+ decoder_variables = {}
+ for v in decoder.trainable_variables + decoder.non_trainable_variables:
+ decoder_variables["/".join(v.name.split("/")[1:])] = v
+
+ _encoder_variables = {}
+ for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables:
+ _encoder_variables["/".join(v.name.split("/")[2:])] = v
+ _decoder_variables = {}
+ for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables:
+ _decoder_variables["/".join(v.name.split("/")[2:])] = v
+
+ # assign weight values to `encoder` and `decoder` from `_tf_model`
+ for name, v in encoder_variables.items():
+ v.assign(_encoder_variables[name])
+ for name, v in decoder_variables.items():
+ v.assign(_decoder_variables[name])
+
+ tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
+
+ # Deal with `enc_to_dec_proj`
+ if hasattr(_tf_model, "enc_to_dec_proj"):
+ tf_model(tf_model.dummy_inputs)
+ tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel)
+ tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ encoder_dir = os.path.join(tmpdirname, "encoder")
+ decoder_dir = os.path.join(tmpdirname, "decoder")
+ tf_model.encoder.save_pretrained(encoder_dir)
+ tf_model.decoder.save_pretrained(decoder_dir)
+
+ if hasattr(tf_model, "enc_to_dec_proj"):
+ # Keras kernels are transposed relative to torch.nn.Linear weights, hence the (1, 0) transpose.
+ enc_to_dec_proj_weight = torch.transpose(
+ torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0
+ )
+ enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy())
+
+ del _tf_model
+ del tf_model
+ gc.collect()
+
+ model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+ encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True
+ )
+ # This is only for copying some specific attributes of this particular model.
+ model.config = config
+
+ if hasattr(model, "enc_to_dec_proj"):
+ model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous()
+ model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous()
+
+ return model
+
+ # standard (PyTorch checkpoint) loading path
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+ @classmethod
+ def from_encoder_decoder_pretrained(
+ cls,
+ encoder_pretrained_model_name_or_path: Optional[str] = None,
+ decoder_pretrained_model_name_or_path: Optional[str] = None,
+ *model_args,
+ **kwargs,
+ ) -> PreTrainedModel:
+ r"""
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ checkpoints.
+
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you need to first set it back in training mode with `model.train()`.
+
+ Params:
+ encoder_pretrained_model_name_or_path (`str`, *optional*):
+ Information necessary to initiate the encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+ - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+ this case, `from_tf` should be set to `True` and a configuration object should be provided as
+ `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+ PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel
+
+ >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased")
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./bert2bert")
+ >>> # load fine-tuned model
+ >>> model = EncoderDecoderModel.from_pretrained("./bert2bert")
+ ```"""
+
+ kwargs_encoder = {
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove encoder, decoder kwargs from kwargs
+ for key in kwargs_encoder.keys():
+ del kwargs["encoder_" + key]
+ for key in kwargs_decoder.keys():
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ encoder = kwargs_encoder.pop("model", None)
+ if encoder is None:
+ if encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_encoder:
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+ )
+
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
+ "from a decoder model. Cross-attention and causal mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_encoder["config"] = encoder_config
+
+ encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ )
+
+ decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+ return cls(encoder=encoder, decoder=decoder, config=config)
+
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> Union[Tuple, Seq2SeqLMOutput]:
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import EncoderDecoderModel, BertTokenizer
+ >>> import torch
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained(
+ ... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased"
+ ... ) # initialize Bert2Bert from pre-trained checkpoints
+
+ >>> # training
+ >>> model.config.decoder_start_token_id = tokenizer.cls_token_id
+ >>> model.config.pad_token_id = tokenizer.pad_token_id
+ >>> model.config.vocab_size = model.config.decoder.vocab_size
+
+ >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids
+ >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids
+ >>> outputs = model(input_ids=input_ids, labels=labels)
+ >>> loss, logits = outputs.loss, outputs.logits
+
+ >>> # save and load from pretrained
+ >>> model.save_pretrained("bert2bert")
+ >>> model = EncoderDecoderModel.from_pretrained("bert2bert")
+
+ >>> # generation
+ >>> generated = model.generate(input_ids)
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+ if "num_items_in_batch" in kwargs_encoder:
+ kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ **kwargs_encoder,
+ )
+ elif isinstance(encoder_outputs, tuple):
+ encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+ decoder_input_ids = shift_tokens_right(
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
+
+ # Decode
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ inputs_embeds=decoder_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ past_key_values=past_key_values,
+ return_dict=return_dict,
+ **kwargs_decoder,
+ )
+
+ # Compute loss independent from decoder (as some shift the logits inside them)
+ loss = None
+ if labels is not None:
+ warnings.warn(DEPRECATION_WARNING, FutureWarning)
+ logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ if loss is not None:
+ return (loss,) + decoder_outputs + encoder_outputs
+ else:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqLMOutput(
+ loss=loss,
+ logits=decoder_outputs.logits,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
+ def resize_token_embeddings(self, *args, **kwargs):
+ raise NotImplementedError(
+ "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
+ " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
+ " model.decoder.resize_token_embeddings(...))"
+ )
+
+ def _reorder_cache(self, past_key_values, beam_idx):
+ # apply decoder cache reordering here
+ return self.decoder._reorder_cache(past_key_values, beam_idx)
+
+
+__all__ = ["EncoderDecoderModel"]
diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdc589484cda1083f42ce96e67b12771f84b6af9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py
@@ -0,0 +1,901 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support Flax Encoder-Decoder architectures"""
+
+import os
+from typing import Optional, Tuple, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax import lax
+from jax.random import PRNGKey
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
+from ...modeling_flax_utils import FlaxPreTrainedModel
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "EncoderDecoderConfig"
+
+ENCODER_DECODER_START_DOCSTRING = r"""
+ This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
+ encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
+ [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`]
+ function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
+ generative task, like summarization.
+
+ The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
+ tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
+ Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
+ Zhou, Wei Li, Peter J. Liu.
+
+ After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
+ (see the examples for more information).
+
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a Flax Linen
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+ Parameters:
+ config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+ `jax.numpy.bfloat16` (on TPUs).
+
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+ specified all the computation will be performed with the given `dtype`.
+
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+ parameters.**
+
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+ [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+ENCODER_DECODER_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+ and prepending them with the `decoder_start_token_id`.
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.encoder.max_position_embeddings - 1]`.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.decoder.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.encoder.max_position_embeddings - 1]`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple.
+"""
+
+ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
+ Args:
+ decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be
+ created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id`
+ and prepending them with the `decoder_start_token_id`.
+ encoder_outputs (`tuple(tuple(jnp.ndarray))`):
+ Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+ encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
+ range `[0, config.decoder.max_position_embeddings - 1]`.
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a
+ plain tuple.
+"""
+
+
+class FlaxEncoderDecoderModule(nn.Module):
+ config: EncoderDecoderConfig
+ dtype: jnp.dtype = jnp.float32
+
+ def setup(self):
+ encoder_config = self.config.encoder
+ decoder_config = self.config.decoder
+
+ # Copied from `modeling_hybrid_clip.py` with modifications.
+ from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
+
+ encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
+ decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
+
+ self.encoder = encoder_module(encoder_config, dtype=self.dtype)
+ self.decoder = decoder_module(decoder_config, dtype=self.dtype)
+
+ # encoder outputs might need to be projected to different dimension for decoder
+ if (
+ self.encoder.config.hidden_size != self.decoder.config.hidden_size
+ and self.decoder.config.cross_attention_hidden_size is None
+ ):
+ self.enc_to_dec_proj = nn.Dense(
+ self.decoder.config.hidden_size,
+ kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
+ dtype=self.dtype,
+ )
+ else:
+ self.enc_to_dec_proj = None
+
+ def _get_encoder_module(self):
+ return self.encoder
+
+ def _get_projection_module(self):
+ return self.enc_to_dec_proj
+
+ def _get_decoder_module(self):
+ return self.decoder
+
+ def __call__(
+ self,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ deterministic: bool = True,
+ ):
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ encoder_hidden_states = encoder_outputs[0]
+
+ # optionally project encoder_hidden_states
+ if self.enc_to_dec_proj is not None:
+ encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=deterministic,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return FlaxSeq2SeqLMOutput(
+ logits=decoder_outputs.logits,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
+ r"""
+ [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with
+ the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as
+ decoder module when created with the [`~FlaxAutoModel.from_pretrained`] class method for the
+ encoder and [`~FlaxAutoModelForCausalLM.from_pretrained`] class method for the decoder.
+ """
+
+ config_class = EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ module_class = FlaxEncoderDecoderModule
+
+ def __init__(
+ self,
+ config: EncoderDecoderConfig,
+ input_shape: Optional[Tuple] = None,
+ seed: int = 0,
+ dtype: jnp.dtype = jnp.float32,
+ _do_init: bool = True,
+ **kwargs,
+ ):
+ if input_shape is None:
+ input_shape = ((1, 1), (1, 1))
+
+ if not _do_init:
+ raise ValueError(
+ "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
+ )
+
+ if config.decoder.cross_attention_hidden_size is not None:
+ if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
+ raise ValueError(
+ "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
+ f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
+ f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
+ " `config.encoder.hidden_size`."
+ )
+
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
+
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
+ encoder_input_shape, decoder_input_shape = input_shape
+
+ # init input tensors
+ input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
+ attention_mask = jnp.ones_like(input_ids)
+ decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
+ if not decoder_batch_size == batch_size:
+ raise ValueError(
+ f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder"
+ f" and {decoder_batch_size} for decoder."
+ )
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
+ )
+
+ params_rng, dropout_rng = jax.random.split(rng)
+ rngs = {"params": params_rng, "dropout": dropout_rng}
+
+ random_params = self.module.init(
+ rngs,
+ input_ids,
+ attention_mask,
+ decoder_input_ids,
+ decoder_attention_mask,
+ position_ids,
+ decoder_position_ids,
+ )["params"]
+
+ if params is not None:
+ random_params = flatten_dict(unfreeze(random_params))
+ params = flatten_dict(unfreeze(params))
+ for missing_key in self._missing_keys:
+ params[missing_key] = random_params[missing_key]
+ self._missing_keys = set()
+ return freeze(unflatten_dict(params))
+ else:
+ return random_params
+
+ def init_cache(self, batch_size, max_length, encoder_outputs):
+ r"""
+ Args:
+ batch_size (`int`):
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
+ max_length (`int`):
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
+ cache.
+ encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
+ `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
+ `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
+ is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
+ cross-attention of the decoder.
+ """
+ # init input variables to retrieve cache
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
+ )
+
+ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
+ decoder_module = module._get_decoder_module()
+ return decoder_module(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ position_ids=decoder_position_ids,
+ **kwargs,
+ )
+
+ init_variables = self.module.init(
+ jax.random.PRNGKey(0),
+ decoder_input_ids=decoder_input_ids,
+ decoder_attention_mask=decoder_attention_mask,
+ decoder_position_ids=decoder_position_ids,
+ encoder_hidden_states=encoder_outputs[0],
+ init_cache=True,
+ method=_decoder_forward, # we only need to call the decoder to init the cache
+ )
+ return unfreeze(init_variables["cache"])
+
+ @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+ def encode(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+
+ >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+
+ >>> text = "My friends are cool but they eat too many carbs."
+ >>> input_ids = tokenizer.encode(text, return_tensors="np")
+ >>> encoder_outputs = model.encode(input_ids)
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # Handle any PRNG if needed
+ rngs = {}
+ if dropout_rng is not None:
+ rngs["dropout"] = dropout_rng
+
+ def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
+ encode_module = module._get_encoder_module()
+ return encode_module(input_ids, attention_mask, position_ids, **kwargs)
+
+ outputs = self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ method=_encoder_forward,
+ )
+
+ if return_dict:
+ outputs = FlaxBaseModelOutput(
+ last_hidden_state=outputs.last_hidden_state,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return outputs
+
    @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Run only the decoder half of the model (with cross-attention over `encoder_outputs`),
        optionally continuing from a `past_key_values` cache created by `init_cache`.

        Returns:

        Example:

        ```python
        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
        >>> import jax.numpy as jnp

        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")

        >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np")
        >>> encoder_outputs = model.encode(input_ids)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        # Fall back to the shared config for unset output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # `encoder_outputs` may be a ModelOutput or a tuple; index 0 is last_hidden_state either way.
        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache, positions depend on how many tokens were already decoded,
            # which cannot be inferred here — the caller must supply them.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed (dropout only, active when train=True).
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
        # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
        # it can be changed by FlaxBartAttention module
        # NOTE(review): this branches on truthiness while the post-processing below
        # uses `is not None` — an empty dict would skip the cache here but still be
        # unpacked below; confirm callers never pass an empty cache.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
        ):
            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states (when encoder/decoder hidden sizes differ)
            if projection_module is not None:
                encoder_hidden_states = projection_module(encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states=encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output: with mutable=["cache"], apply() returns
        # (outputs, mutated_variables); unpack and surface the new cache to the caller.
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            # Splice the cache in as the second tuple element.
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs
+
+ @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+ def __call__(
+ self,
+ input_ids: jnp.ndarray,
+ attention_mask: Optional[jnp.ndarray] = None,
+ decoder_input_ids: Optional[jnp.ndarray] = None,
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
+ position_ids: Optional[jnp.ndarray] = None,
+ decoder_position_ids: Optional[jnp.ndarray] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ train: bool = False,
+ params: dict = None,
+ dropout_rng: PRNGKey = None,
+ ):
+ r"""
+ Returns:
+
+ Examples:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
+
+ >>> # load a fine-tuned bert2gpt2 model
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
+ >>> # load input & output tokenizer
+ >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
+ >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+
+ >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
+ >>> singing a racist chant. SAE's national chapter suspended the students,
+ >>> but University of Oklahoma President David Boren took it a step further,
+ >>> saying the university's affiliation with the fraternity is permanently done.'''
+
+ >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids
+
+ >>> # use GPT2's eos_token as the pad as well as eos token
+ >>> model.config.eos_token_id = model.config.decoder.eos_token_id
+ >>> model.config.pad_token_id = model.config.eos_token_id
+
+ >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
+
+ >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
+ >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
+ ```
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+ # prepare encoder inputs
+ if attention_mask is None:
+ attention_mask = jnp.ones_like(input_ids)
+ if position_ids is None:
+ batch_size, sequence_length = input_ids.shape
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
+
+ # prepare decoder inputs
+ if decoder_input_ids is None:
+ raise ValueError(
+ "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must"
+ " be specified as an input argument."
+ )
+ if decoder_attention_mask is None:
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
+ if decoder_position_ids is None:
+ batch_size, sequence_length = decoder_input_ids.shape
+ decoder_position_ids = jnp.broadcast_to(
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
+ )
+
+ # Handle any PRNG if needed
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
+
+ return self.module.apply(
+ {"params": params or self.params},
+ input_ids=jnp.array(input_ids, dtype="i4"),
+ attention_mask=jnp.array(attention_mask, dtype="i4"),
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
+ position_ids=jnp.array(position_ids, dtype="i4"),
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ deterministic=not train,
+ rngs=rngs,
+ )
+
    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the kwargs `generate` passes to `decode` for the first decoding step.

        Initializes the decoder cache via `init_cache` and precomputes a static,
        `max_length`-sized decoder attention mask plus matching position ids.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            # Positions are the running count of attended tokens (cumsum-1 of the mask);
            # copy the caller's mask into the static buffer's leading slots.
            decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": decoder_position_ids,
        }
+
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
+ model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+ return model_kwargs
+
+ @classmethod
+ def from_encoder_decoder_pretrained(
+ cls,
+ encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ *model_args,
+ **kwargs,
+ ) -> FlaxPreTrainedModel:
+ r"""
+ Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
+ checkpoints.
+
+ Params:
+ encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
+ Information necessary to initiate the encoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
+ Information necessary to initiate the decoder. Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ - A path to a *directory* containing model weights saved using
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+
+ model_args (remaining positional arguments, *optional*):
+ All remaning positional arguments will be passed to the underlying model's `__init__` method.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+ `output_attentions=True`).
+
+ - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
+ - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+ Behaves differently depending on whether a `config` is provided or automatically loaded.
+
+ Example:
+
+ ```python
+ >>> from transformers import FlaxEncoderDecoderModel
+
+ >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
+ >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")
+ >>> # saving model after fine-tuning
+ >>> model.save_pretrained("./bert2gpt2")
+ >>> # load fine-tuned model
+ >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")
+ ```"""
+
+ kwargs_encoder = {
+ argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
+ }
+
+ kwargs_decoder = {
+ argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+ }
+
+ # remove encoder, decoder kwargs from kwargs
+ for key in kwargs_encoder.keys():
+ del kwargs["encoder_" + key]
+ for key in kwargs_decoder.keys():
+ del kwargs["decoder_" + key]
+
+ # Load and initialize the encoder and decoder
+ # The distinction between encoder and decoder at the model level is made
+ # by the value of the flag `is_decoder` that we need to set correctly.
+ encoder = kwargs_encoder.pop("model", None)
+ if encoder is None:
+ if encoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_encoder:
+ encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+ encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+ )
+ if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
+ logger.info(
+ f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
+ "from a decoder model. Cross-attention and casual mask are disabled."
+ )
+ encoder_config.is_decoder = False
+ encoder_config.add_cross_attention = False
+
+ kwargs_encoder["config"] = encoder_config
+
+ encoder = FlaxAutoModel.from_pretrained(
+ encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
+ )
+
+ decoder = kwargs_decoder.pop("model", None)
+ if decoder is None:
+ if decoder_pretrained_model_name_or_path is None:
+ raise ValueError(
+ "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
+ "to be defined."
+ )
+
+ if "config" not in kwargs_decoder:
+ decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+ decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+ )
+ if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
+ logger.info(
+ f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
+ f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
+ f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+ )
+ decoder_config.is_decoder = True
+ decoder_config.add_cross_attention = True
+
+ kwargs_decoder["config"] = decoder_config
+
+ if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
+ logger.warning(
+ f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
+ f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
+ "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
+ "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
+ "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+ )
+
+ decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+
+ # instantiate config with corresponding kwargs
+ dtype = kwargs.pop("dtype", jnp.float32)
+ config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
+
+ # init model
+ model = cls(config, dtype=dtype)
+ model.params["encoder"] = encoder.params
+ model.params["decoder"] = decoder.params
+
+ return model
+
+
+__all__ = ["FlaxEncoderDecoderModel"]
diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5abafc361b6b4e71e365d7b7f3caa3368f82d92
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
@@ -0,0 +1,665 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Classes to support TF Encoder-Decoder architectures"""
+
+from __future__ import annotations
+
+import inspect
+import re
+import warnings
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
+from ...modeling_tf_utils import (
+ TFCausalLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ get_initializer,
+ keras,
+ unpack_inputs,
+)
+from ...tf_utils import shape_list
+from ...utils import (
+ ModelOutput,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from ..auto.configuration_auto import AutoConfig
+from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
+from .configuration_encoder_decoder import EncoderDecoderConfig
+
+
# Module-level logger for load-time warnings/infos.
logger = logging.get_logger(__name__)

# Config class name substituted into the auto-generated return docstrings.
_CONFIG_FOR_DOC = "EncoderDecoderConfig"

# Emitted for checkpoints trained before v4.17.0, when the loss was computed
# inside the decoder instead of by this encoder-decoder framework.
DEPRECATION_WARNING = (
    "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the"
    " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if"
    " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the"
    " labels, no need to pass them yourself anymore."
)

# Class-level docstring injected via `add_start_docstrings` on the model class.
ENCODER_DECODER_START_DOCSTRING = r"""
    This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
    encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
    [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`]
    function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream
    generative task, like summarization.

    The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
    tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation
    Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi
    Zhou, Wei Li, Peter J. Liu.

    After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
    (see the examples for more information).

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
+
# Forward-pass docstring template; `{0}` is filled with the input shape by
# `add_start_docstrings_to_model_forward`. Fixed broken inline-code markup:
# missing comma + stray double backticks around `Dict[str, tf.Tensor]` and
# `**decoder_kwargs`.
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            Provide for sequence to sequence training to the decoder. Indices can be obtained using
            [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
            details.
        decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*):
            This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
            `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output
            of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `({0})`.
        inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. This is useful if you want more control over how to convert `decoder_input_ids` indices
            into associated vectors than the model's internal embedding lookup matrix.
        labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
            Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0,
            ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
        kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:

            - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
            - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function.
"""
+
+
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Shift `input_ids` one position to the right, prepending `decoder_start_token_id`.

    `-100` label placeholders in the shifted result are replaced by `pad_token_id`,
    and an assertion op verifies no negative ids remain.
    """
    if pad_token_id is None:
        raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.")
    if decoder_start_token_id is None:
        raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.")

    pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
    decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)

    # Prepend a start-token column, dropping the last token of every row.
    start_column = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
    shifted = tf.concat([start_column, input_ids[:, :-1]], -1)

    # Labels use -100 as the ignore index; those slots become padding here.
    shifted = tf.where(shifted == -100, tf.fill(shape_list(shifted), pad_token_id), shifted)

    # Verify that the result has only non-negative values (i.e. only valid ids / -100 went in).
    assert_gte0 = tf.debugging.assert_greater_equal(shifted, tf.constant(0, dtype=input_ids.dtype))

    # Wrap in an identity no-op so the assertion op is actually executed in graph mode.
    with tf.control_dependencies([assert_gte0]):
        shifted = tf.identity(shifted)

    return shifted
+
+
+@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
+class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
+ r"""
+ [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one
+ of the base model classes of the library as encoder and another one as decoder when created with the
+ [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class
+ method for the decoder.
+ """
+
+ config_class = EncoderDecoderConfig
+ base_model_prefix = "encoder_decoder"
+ load_weight_prefix = "tf_encoder_decoder_model"
+
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[TFPreTrainedModel] = None,
        decoder: Optional[TFPreTrainedModel] = None,
    ):
        """Create the composite model from a config, explicit sub-models, or both.

        Either `config` or both `encoder` and `decoder` must be given; a missing
        config is derived from the sub-models' configs, and missing sub-models are
        built from the config via the TF auto classes.
        """
        if config is None and (encoder is None or decoder is None):
            raise ValueError("Either a configuration or an encoder and a decoder has to be provided.")
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"config: {config} has to be of type {self.config_class}")

        # A decoder-declared cross-attention width must match the encoder's hidden size.
        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        # initialize with config
        super().__init__(config)

        if encoder is None:
            encoder = TFAutoModel.from_config(config.encoder, name="encoder")

        if decoder is None:
            decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder")

        self.encoder = encoder
        self.decoder = decoder

        # Warn when passed sub-models carry configs that diverge from the shared one,
        # since the shared config wins below.
        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # encoder outputs might need to be projected to different dimension for decoder
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = keras.layers.Dense(
                units=self.decoder.config.hidden_size,
                kernel_initializer=get_initializer(config.encoder.initializer_range),
                name="enc_to_dec_proj",
            )

        # The encoder must be a bare model: an LM head would be dead weight here.
        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

        # The decoder's call() must accept encoder_hidden_states for cross-attention to work.
        decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
        if "encoder_hidden_states" not in decoder_signature:
            raise ValueError(
                "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
                "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
            )
+
def get_encoder(self):
    """Return the wrapped encoder sub-model."""
    return self.encoder
+
def get_decoder(self):
    """Return the wrapped decoder sub-model."""
    return self.decoder
+
def get_input_embeddings(self):
    """Return the input embeddings, which live on the encoder sub-model."""
    return self.encoder.get_input_embeddings()
+
def get_output_embeddings(self):
    """Return the output (LM head) embeddings, which live on the decoder sub-model."""
    return self.decoder.get_output_embeddings()
+
def set_output_embeddings(self, new_embeddings):
    """Replace the decoder's output (LM head) embeddings with `new_embeddings`.

    Delegates to the decoder; the return value is whatever the decoder's
    `set_output_embeddings` returns.
    """
    return self.decoder.set_output_embeddings(new_embeddings)
+
def tf_to_pt_weight_rename(self, tf_weight):
    """Map a TF weight name to its PyTorch counterpart when crossloading PT checkpoints.

    The TF base classes wrap the model stem in an extra MainLayer compared to the PT
    models; its name is assumed to equal the encoder config's `model_type`. Stripping
    that segment from encoder weight names makes TF and PT names line up. Returns a
    one-element tuple of candidate names, as the crossloading machinery expects.
    """
    # This rename is only needed when crossloading weights from PT. Since weights are
    # often safetensors now, whether crossloading happens is only known after sniffing
    # the weights file — so we always provide the rename and let the super method
    # decide whether to use it.
    encoder_model_type = self.config.encoder.model_type
    is_encoder_only_weight = "encoder" in tf_weight and "decoder" not in tf_weight
    if not is_encoder_only_weight:
        return (tf_weight,)
    renamed = re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight)
    return (renamed,)
+
@classmethod
def from_encoder_decoder_pretrained(
    cls,
    encoder_pretrained_model_name_or_path: Optional[str] = None,
    decoder_pretrained_model_name_or_path: Optional[str] = None,
    *model_args,
    **kwargs,
) -> TFPreTrainedModel:
    r"""
    Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
    checkpoints.


    Params:
        encoder_pretrained_model_name_or_path (`str`, *optional*):
            Information necessary to initiate the encoder. Can be either:

            - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
            - A path to a *directory* containing model weights saved using
              [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
            - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case,
              `encoder_from_pt` should be set to `True`.

        decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
            Information necessary to initiate the decoder. Can be either:

            - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
            - A path to a *directory* containing model weights saved using
              [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
            - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case,
              `decoder_from_pt` should be set to `True`.

        model_args (remaining positional arguments, *optional*):
            All remaining positional arguments will be passed to the underlying model's `__init__` method.

        kwargs (remaining dictionary of keyword arguments, *optional*):
            Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
            `output_attentions=True`).

            - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
            - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
            - To update the parent model configuration, do not use a prefix for each configuration parameter.

            Behaves differently depending on whether a `config` is provided or automatically loaded.

    Example:

    ```python
    >>> from transformers import TFEncoderDecoderModel

    >>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized
    >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2")
    >>> # saving model after fine-tuning
    >>> model.save_pretrained("./bert2gpt2")
    >>> # load fine-tuned model
    >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2")
    ```"""

    # Split the kwargs: `encoder_*`-prefixed keys go to the encoder, `decoder_*`-prefixed
    # keys go to the decoder, with the prefix stripped.
    kwargs_encoder = {
        argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
    }

    kwargs_decoder = {
        argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
    }

    # remove encoder, decoder kwargs from kwargs
    for key in kwargs_encoder.keys():
        del kwargs["encoder_" + key]
    for key in kwargs_decoder.keys():
        del kwargs["decoder_" + key]

    # Load and initialize the encoder and decoder
    # The distinction between encoder and decoder at the model level is made
    # by the value of the flag `is_decoder` that we need to set correctly.
    encoder = kwargs_encoder.pop("model", None)
    if encoder is None:
        if encoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `encoder_model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_encoder:
            encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
            if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                logger.info(
                    f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model "
                    "from a decoder model. Cross-attention and causal mask are disabled."
                )
                # Force plain-encoder behavior regardless of how the checkpoint was configured.
                encoder_config.is_decoder = False
                encoder_config.add_cross_attention = False

            kwargs_encoder["config"] = encoder_config

        kwargs_encoder["name"] = "encoder"
        kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix
        encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)

    decoder = kwargs_decoder.pop("model", None)
    if decoder is None:
        if decoder_pretrained_model_name_or_path is None:
            raise ValueError(
                "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                "to be defined."
            )

        if "config" not in kwargs_decoder:
            decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
            if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                logger.info(
                    f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention"
                    f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if"
                    f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                )
                decoder_config.is_decoder = True
                decoder_config.add_cross_attention = True

            kwargs_decoder["config"] = decoder_config

        # Warn (but do not override) when a user-supplied decoder config is not decoder-flavored.
        if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
            logger.warning(
                f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
            )

        kwargs_decoder["name"] = "decoder"
        kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix
        decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

    # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly.
    if encoder.name != "encoder":
        raise ValueError("encoder model must be created with the name `encoder`.")
    if decoder.name != "decoder":
        raise ValueError("decoder model must be created with the name `decoder`.")

    # instantiate config with corresponding kwargs
    config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
    return cls(encoder=encoder, decoder=decoder, config=config)
+
@unpack_inputs
@add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
    self,
    input_ids: TFModelInputType | None = None,
    attention_mask: np.ndarray | tf.Tensor | None = None,
    decoder_input_ids: np.ndarray | tf.Tensor | None = None,
    decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
    encoder_outputs: np.ndarray | tf.Tensor | None = None,
    past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
    inputs_embeds: np.ndarray | tf.Tensor | None = None,
    decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
    labels: np.ndarray | tf.Tensor | None = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    training: bool = False,
    **kwargs,
) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]:
    r"""
    Returns:

    Examples:

    ```python
    >>> from transformers import TFEncoderDecoderModel, BertTokenizer

    >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
    >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2")

    >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")

    >>> # forward
    >>> input_ids = tokenizer.encode(
    ...     "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf"
    ... )  # Batch size 1
    >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)

    >>> # training
    >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
    >>> loss, logits = outputs.loss, outputs.logits

    >>> # save and load from pretrained
    >>> model.save_pretrained("bert2gpt2")
    >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2")

    >>> # generation
    >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
    ```"""
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Split extra kwargs: un-prefixed keys go to the encoder, `decoder_*`-prefixed
    # keys (prefix stripped) go to the decoder.
    kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

    kwargs_decoder = {
        argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
    }

    # Let the user be responsible for the expected format.
    if encoder_outputs is not None:
        if return_dict and not isinstance(encoder_outputs, ModelOutput):
            raise ValueError(
                "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of "
                f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`."
            )

    # Run the encoder only if its outputs were not supplied by the caller.
    if encoder_outputs is None:
        encoder_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
            "training": training,
        }

        # Add arguments to encoder from `kwargs_encoder`
        encoder_inputs.update(kwargs_encoder)

        # Handle the case where the inputs are passed as a single dict which contains `labels`.
        # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this
        # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`).
        if "labels" in encoder_inputs:
            labels = encoder_inputs.pop("labels")

        # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
        if "decoder_input_ids" in encoder_inputs:
            decoder_input_ids = encoder_inputs.pop("decoder_input_ids")
        # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`.
        if "decoder_attention_mask" in encoder_inputs:
            decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask")

        encoder_outputs = self.encoder(**encoder_inputs)

    encoder_hidden_states = encoder_outputs[0]

    # optionally project encoder_hidden_states
    if (
        self.encoder.config.hidden_size != self.decoder.config.hidden_size
        and self.decoder.config.cross_attention_hidden_size is None
    ):
        encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

    # When only labels are given, derive teacher-forcing decoder inputs by shifting
    # the labels one position to the right.
    if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
        decoder_input_ids = shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

    decoder_inputs = {
        "input_ids": decoder_input_ids,
        "attention_mask": decoder_attention_mask,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask,
        "inputs_embeds": decoder_inputs_embeds,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "use_cache": use_cache,
        "past_key_values": past_key_values,
        "return_dict": return_dict,
        "training": training,
    }

    # Add arguments to decoder from `kwargs_decoder`
    decoder_inputs.update(kwargs_decoder)

    decoder_outputs = self.decoder(**decoder_inputs)

    logits = decoder_outputs[0]

    # Compute loss independent from decoder (as some shift the logits inside them)
    loss = None
    if labels is not None:
        warnings.warn(DEPRECATION_WARNING, FutureWarning)
        loss = self.hf_compute_loss(labels, logits)

    if not return_dict:
        past_key_values = None
        if use_cache:
            past_key_values = decoder_outputs[1]
        # The starting index of the remaining elements in `decoder_outputs`
        start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])

        if not isinstance(encoder_outputs, tuple):
            encoder_outputs = encoder_outputs.to_tuple()
        # Tuple output: (loss, logits, past, remaining decoder outputs, encoder outputs),
        # with all `None` entries dropped.
        output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
        output = tuple([x for x in output if x is not None])
        return output

    return TFSeq2SeqLMOutput(
        loss=loss,
        logits=decoder_outputs.logits,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
+
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
    """Assemble the inputs for one decoding step of `generate`.

    Delegates the decoder-side preparation to the wrapped decoder, then packs the
    encoder outputs and masks into the dict the composite model's `call` expects.
    """
    decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
    cache = decoder_inputs.get("past_key_values")
    if cache is None:
        # Some TF models (e.g. TF GPT2) expose the cache under the legacy `past` key.
        cache = decoder_inputs.get("past")
    return {
        "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
        "attention_mask": attention_mask,
        "decoder_attention_mask": decoder_inputs.get("attention_mask"),
        "decoder_input_ids": decoder_inputs["input_ids"],
        # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
        "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
        "past_key_values": cache,
        "use_cache": use_cache,
    }
+
def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
    """Shift `labels` one position to the right to build teacher-forcing decoder inputs."""
    return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
def resize_token_embeddings(self, *args, **kwargs):
    """Unsupported on the composite model; resize the wrapped encoder/decoder instead.

    Raises:
        NotImplementedError: always. (Fixed: the original message was missing the
        space in "supported.Please".)
    """
    raise NotImplementedError(
        "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported. Please use the"
        " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
        " model.decoder.resize_token_embeddings(...))"
    )
+
def _reorder_cache(self, past, beam_idx):
    # Beam-search cache reordering is delegated to the wrapped decoder, which
    # knows the layout of its own `past` cache.
    return self.decoder._reorder_cache(past, beam_idx)
+
def build(self, input_shape=None):
    """Create the layer's variables (Keras build step); idempotent."""
    if self.built:
        return
    self.built = True
    # Build each sub-layer inside a name scope matching its layer name, so weight
    # names line up with what `from_pretrained` expects to load.
    if getattr(self, "enc_to_dec_proj", None) is not None:
        with tf.name_scope(self.enc_to_dec_proj.name):
            # The projection consumes encoder hidden states of size `encoder.hidden_size`.
            self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size])
    if getattr(self, "encoder", None) is not None:
        with tf.name_scope(self.encoder.name):
            self.encoder.build(None)
    if getattr(self, "decoder", None) is not None:
        with tf.name_scope(self.decoder.name):
            self.decoder.build(None)
+
+
# Public API of this module.
__all__ = ["TFEncoderDecoderModel"]
diff --git a/docs/transformers/build/lib/transformers/models/ernie/__init__.py b/docs/transformers/build/lib/transformers/models/ernie/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb8983063ddb0117e8b0d7cd6603aa6ac3056b6
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/ernie/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers see the real symbols eagerly.
    from .configuration_ernie import *
    from .modeling_ernie import *
else:
    import sys

    # At runtime, replace this module object with a lazy proxy so that the heavy
    # submodules are only imported on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py b/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py
new file mode 100644
index 0000000000000000000000000000000000000000..655e40e163b59dac4f2cab5fe96265b2173478c1
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py
@@ -0,0 +1,163 @@
+# coding=utf-8
+# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ERNIE model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
class ErnieConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ErnieModel`] or a [`TFErnieModel`]. It is used to
    instantiate an ERNIE model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the ERNIE
    [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the ERNIE model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`].
        task_type_vocab_size (`int`, *optional*, defaults to 3):
            The vocabulary size of the `task_type_ids` for the ERNIE2.0/ERNIE3.0 model.
        use_task_id (`bool`, *optional*, defaults to `False`):
            Whether or not the model supports `task_type_ids`.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import ErnieConfig, ErnieModel

    >>> # Initializing a ERNIE nghuyong/ernie-3.0-base-zh style configuration
    >>> configuration = ErnieConfig()

    >>> # Initializing a model (with random weights) from the nghuyong/ernie-3.0-base-zh style configuration
    >>> model = ErnieModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ernie"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        task_type_vocab_size=3,
        use_task_id=False,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        # `pad_token_id` is consumed by the parent class; remaining kwargs are
        # forwarded to `PretrainedConfig` as-is.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.task_type_vocab_size = task_type_vocab_size
        self.use_task_id = use_task_id
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
+
+
class ErnieOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Declare the ONNX export inputs and their dynamic axes.

        Multiple-choice tasks carry an extra `choice` axis between batch and sequence.
        """
        axes = (
            {0: "batch", 1: "choice", 2: "sequence"}
            if self.task == "multiple-choice"
            else {0: "batch", 1: "sequence"}
        )
        input_names = ("input_ids", "attention_mask", "token_type_ids", "task_type_ids")
        return OrderedDict((name, axes) for name in input_names)
+
+
# Public API of this module.
__all__ = ["ErnieConfig", "ErnieOnnxConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py b/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py
new file mode 100644
index 0000000000000000000000000000000000000000..221dd57485880f6210930376dc536028f3463452
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py
@@ -0,0 +1,1825 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ERNIE model."""
+
+import math
+import warnings
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...generation import GenerationMixin
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ NextSentencePredictorOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ ModelOutput,
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_ernie import ErnieConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "nghuyong/ernie-1.0-base-zh"
+_CONFIG_FOR_DOC = "ErnieConfig"
+
+
+class ErnieEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+        # ERNIE-specific: an optional task-type embedding table, enabled via `config.use_task_id`.
+        self.use_task_id = config.use_task_id
+        if config.use_task_id:
+            self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        task_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
+        """Sum word, token-type, (absolute) position and optional task-type embeddings, then LayerNorm + dropout."""
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        seq_length = input_shape[1]
+
+        # Offset position ids by the cached prefix length so generation continues at the right positions.
+        if position_ids is None:
+            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
+        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
+        # issue #5664
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + token_type_embeddings
+        if self.position_embedding_type == "absolute":
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings += position_embeddings
+
+        # add `task_type_id` for ERNIE model
+        if self.use_task_id:
+            if task_type_ids is None:
+                # Default to task id 0 when the caller does not provide task types.
+                task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+            task_type_embeddings = self.task_type_embeddings(task_type_ids)
+            embeddings += task_type_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Ernie
+class ErnieSelfAttention(nn.Module):
+    """Multi-head self-/cross-attention (BERT-style) with optional relative position scores and a decoder KV cache."""
+
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+            raise ValueError(
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            # Embeds pairwise token distances in [-(max-1), max-1]; hence the 2*max - 1 table size.
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+        self.is_decoder = config.is_decoder
+
+    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+        """Reshape (batch, seq, all_head_size) -> (batch, num_heads, seq, head_size)."""
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        mixed_query_layer = self.query(hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            # Self-attention with cache: append the new keys/values to the cached ones along the sequence axis.
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        use_cache = past_key_value is not None
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+            if use_cache:
+                # With a cache only the latest query position is present.
+                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+                    -1, 1
+                )
+            else:
+                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+            distance = position_ids_l - position_ids_r
+
+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in ErnieModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+
+        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Ernie
+class ErnieSelfOutput(nn.Module):
+    """Output projection after self-attention: dense + dropout, then residual add and LayerNorm."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        # Residual connection with the attention input, followed by LayerNorm.
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+# Maps `config._attn_implementation` to an attention class (see ErnieAttention.__init__);
+# only the eager implementation is provided here.
+ERNIE_SELF_ATTENTION_CLASSES = {
+    "eager": ErnieSelfAttention,
+}
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE
+class ErnieAttention(nn.Module):
+    """Full attention block: ErnieSelfAttention followed by ErnieSelfOutput, with head-pruning support."""
+
+    def __init__(self, config, position_embedding_type=None):
+        super().__init__()
+        self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation](
+            config, position_embedding_type=position_embedding_type
+        )
+        self.output = ErnieSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        """Remove the given attention heads from the q/k/v and output projections."""
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Ernie
+class ErnieIntermediate(nn.Module):
+    """Feed-forward expansion: hidden_size -> intermediate_size followed by the configured activation."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        # `hidden_act` may be an activation name (looked up in ACT2FN) or a callable.
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Ernie
+class ErnieOutput(nn.Module):
+    """Feed-forward projection back to hidden_size with dropout, residual add and LayerNorm."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        # Residual connection with the feed-forward input, followed by LayerNorm.
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie
+class ErnieLayer(nn.Module):
+    """One transformer layer: self-attention, optional cross-attention (decoder), then chunked feed-forward."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = ErnieAttention(config)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            # Cross-attention always uses absolute position embeddings.
+            self.crossattention = ErnieAttention(config, position_embedding_type="absolute")
+        self.intermediate = ErnieIntermediate(config)
+        self.output = ErnieOutput(config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+                    " by setting `config.add_cross_attention=True`"
+                )
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                cross_attn_past_key_value,
+                output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Feed-forward applied in sequence-length chunks to bound peak memory.
+        layer_output = apply_chunking_to_forward(
+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+        )
+        outputs = (layer_output,) + outputs
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        """Apply the intermediate + output feed-forward sublayers to one chunk of the sequence."""
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Ernie
+class ErnieEncoder(nn.Module):
+    """Stack of `num_hidden_layers` ErnieLayer modules with optional gradient checkpointing and KV caching."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList([ErnieLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        # Caching is incompatible with gradient checkpointing, which recomputes activations.
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        next_decoder_cache = () if use_cache else None
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # Per-layer head mask / cached key-values, when provided.
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            # Tuple output: drop the entries that were not requested (None values).
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Ernie
+class ErniePooler(nn.Module):
+    """Pools the sequence by passing the first token's hidden state through a dense + tanh layer."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Ernie
+class ErniePredictionHeadTransform(nn.Module):
+    """Dense + activation + LayerNorm transform applied before the LM decoder projection."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        # `hidden_act` may be an activation name (looked up in ACT2FN) or a callable.
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Ernie
+class ErnieLMPredictionHead(nn.Module):
+    """Language-modeling head: transform followed by a vocab-size decoder projection with its own bias."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.transform = ErniePredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def _tie_weights(self):
+        # Re-link after any operation that may have replaced the decoder's bias.
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Ernie
+class ErnieOnlyMLMHead(nn.Module):
+    """Thin wrapper exposing only the masked-LM prediction head."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = ErnieLMPredictionHead(config)
+
+    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->Ernie
+class ErnieOnlyNSPHead(nn.Module):
+    """Next-sentence-prediction head: a binary classifier over the pooled output."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, pooled_output):
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return seq_relationship_score
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->Ernie
+class ErniePreTrainingHeads(nn.Module):
+    """Combined pre-training heads: masked-LM predictions plus the binary next-sentence classifier."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = ErnieLMPredictionHead(config)
+        self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+    def forward(self, sequence_output, pooled_output):
+        prediction_scores = self.predictions(sequence_output)
+        seq_relationship_score = self.seq_relationship(pooled_output)
+        return prediction_scores, seq_relationship_score
+
+
+class ErniePreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ErnieConfig
+    base_model_prefix = "ernie"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                # Keep the padding token's embedding at zero.
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            # LayerNorm starts as the identity affine transform (bias 0, weight 1).
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+@dataclass
+# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie
+class ErnieForPreTrainingOutput(ModelOutput):
+    """
+    Output type of [`ErnieForPreTraining`].
+
+    Args:
+        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+            Total loss as the sum of the masked language modeling loss and the next sequence prediction
+            (classification) loss.
+        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    # All fields default to None; see the docstring above for when each is populated.
+    loss: Optional[torch.FloatTensor] = None
+    prediction_logits: Optional[torch.FloatTensor] = None
+    seq_relationship_logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+# Shared class-level docstring fragment, injected via the `add_start_docstrings` decorator below.
+ERNIE_START_DOCSTRING = r"""
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`ErnieConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+# Forward-signature docstring template; `{0}` is filled with the input shape via `.format(...)`
+# in `add_start_docstrings_to_model_forward`.
+ERNIE_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+            1]`:
+
+            - 0 corresponds to a *sentence A* token,
+            - 1 corresponds to a *sentence B* token.
+
+            [What are token type IDs?](../glossary#token-type-ids)
+        task_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Task type embedding is a special embedding to represent the characteristic of different tasks, such as
+            word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We
+            assign a `task_type_id` to each task and the `task_type_id` is in the range `[0,
+            config.task_type_vocab_size-1]
+        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare Ernie Model transformer outputting raw hidden-states without any specific head on top.",
+    ERNIE_START_DOCSTRING,
+)
+class ErnieModel(ErniePreTrainedModel):
+    """
+    Bare ERNIE encoder: embedding layer (which also consumes the ERNIE-specific `task_type_ids`),
+    transformer encoder stack, and an optional pooler over the first token.
+
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
+    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+
+    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
+    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+    """
+
+    # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie
+    def __init__(self, config, add_pooling_layer=True):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = ErnieEmbeddings(config)
+        self.encoder = ErnieEncoder(config)
+
+        # Task heads that do not need the pooled [CLS] representation pass add_pooling_layer=False
+        # to skip allocating the pooler entirely.
+        self.pooler = ErniePooler(config) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        r"""
+        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Key/value caching only makes sense in decoder mode; force it off otherwise.
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        # Exactly one of input_ids / inputs_embeds must be provided.
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # input_shape is (batch_size, seq_length) at this point.
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        # Default mask attends to every position, including cached past positions.
+        if attention_mask is None:
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+        if token_type_ids is None:
+            # Prefer the registered all-zeros buffer (already on the right device/dtype) when present.
+            if hasattr(self.embeddings, "token_type_ids"):
+                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+                token_type_ids = buffered_token_type_ids_expanded
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # task_type_ids is the ERNIE-specific addition relative to BERT embeddings.
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Ernie Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
+    sentence prediction (classification)` head.
+    """,
+    ERNIE_START_DOCSTRING,
+)
+class ErnieForPreTraining(ErniePreTrainedModel):
+    # Decoder weights/bias are tied to the input embeddings; listed so save/load handles them.
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.ernie = ErnieModel(config)
+        self.cls = ErniePreTrainingHeads(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        next_sentence_label: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], ErnieForPreTrainingOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
+            pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
+
+            - 0 indicates sequence B is a continuation of sequence A,
+            - 1 indicates sequence B is a random sequence.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ErnieForPreTraining
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
+        >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh")
+
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.prediction_logits
+        >>> seq_relationship_logits = outputs.seq_relationship_logits
+        ```
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output, pooled_output = outputs[:2]
+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+        # Total pretraining loss = MLM loss + NSP loss; computed only when BOTH label sets are given.
+        total_loss = None
+        if labels is not None and next_sentence_label is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+            total_loss = masked_lm_loss + next_sentence_loss
+
+        if not return_dict:
+            output = (prediction_scores, seq_relationship_score) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return ErnieForPreTrainingOutput(
+            loss=total_loss,
+            prediction_logits=prediction_scores,
+            seq_relationship_logits=seq_relationship_score,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING
+)
+class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin):
+    # Decoder weights/bias are tied to the input embeddings; listed so save/load handles them.
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+
+        if not config.is_decoder:
+            logger.warning("If you want to use `ErnieForCausalLM` as a standalone, add `is_decoder=True.`")
+
+        # No pooler needed: the LM head only consumes per-token hidden states.
+        self.ernie = ErnieModel(config, add_pooling_layer=False)
+        self.cls = ErnieOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+        r"""
+        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        # Training (labels given) never needs the KV cache.
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.ernie(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            # self.loss_function (inherited) handles the causal shift; extra kwargs
+            # (e.g. num_items_in_batch) are forwarded to it untouched.
+            lm_loss = self.loss_function(
+                prediction_scores,
+                labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache
+    def _reorder_cache(self, past_key_values, beam_idx):
+        # Re-aligns the per-layer KV cache with the selected beams during beam search.
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+
+@add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING)
+class ErnieForMaskedLM(ErniePreTrainedModel):
+    # Decoder weights/bias are tied to the input embeddings; listed so save/load handles them.
+    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+
+        if config.is_decoder:
+            logger.warning(
+                "If you want to use `ErnieForMaskedLM` make sure `config.is_decoder=False` for "
+                "bi-directional self-attention."
+            )
+
+        # No pooler needed: the MLM head only consumes per-token hidden states.
+        self.ernie = ErnieModel(config, add_pooling_layer=False)
+        self.cls = ErnieOnlyMLMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.get_output_embeddings
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=MaskedLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output="'paris'",
+        expected_loss=0.88,
+    )
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        prediction_scores = self.cls(sequence_output)
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.prepare_inputs_for_generation
+    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
+        # Appends one PAD "dummy" token (with a 0 mask entry) so the MLM head has a
+        # position to predict at the end of the sequence.
+        # NOTE(review): assumes attention_mask is not None — would raise AttributeError
+        # otherwise; confirm callers always supply it.
+        input_shape = input_ids.shape
+        effective_batch_size = input_shape[0]
+
+        # add a dummy token
+        if self.config.pad_token_id is None:
+            raise ValueError("The PAD token should be defined for generation")
+
+        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
+        dummy_token = torch.full(
+            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
+        )
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        return {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
+        `prepare_inputs_for_generation` method.
+        """
+        return False
+
+
+@add_start_docstrings(
+    """Ernie Model with a `next sentence prediction (classification)` head on top.""",
+    ERNIE_START_DOCSTRING,
+)
+class ErnieForNextSentencePrediction(ErniePreTrainedModel):
+    # Copied from transformers.models.bert.modeling_bert.BertForNextSentencePrediction.__init__ with Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.ernie = ErnieModel(config)
+        self.cls = ErnieOnlyNSPHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
+            (see `input_ids` docstring). Indices should be in `[0, 1]`:
+
+            - 0 indicates sequence B is a continuation of sequence A,
+            - 1 indicates sequence B is a random sequence.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction
+        >>> import torch
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh")
+        >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh")
+
+        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+
+        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
+        >>> logits = outputs.logits
+        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
+        ```
+        """
+
+        # Backward-compatibility shim: `next_sentence_label` was the old name for `labels`.
+        if "next_sentence_label" in kwargs:
+            warnings.warn(
+                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
+                " `labels` instead.",
+                FutureWarning,
+            )
+            labels = kwargs.pop("next_sentence_label")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # NSP classifies from the pooled [CLS] representation (index 1 of the base model output).
+        pooled_output = outputs[1]
+
+        seq_relationship_scores = self.cls(pooled_output)
+
+        next_sentence_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
+
+        if not return_dict:
+            output = (seq_relationship_scores,) + outputs[2:]
+            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
+
+        return NextSentencePredictorOutput(
+            loss=next_sentence_loss,
+            logits=seq_relationship_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Ernie Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    output) e.g. for GLUE tasks.
+    """,
+    ERNIE_START_DOCSTRING,
+)
+class ErnieForSequenceClassification(ErniePreTrainedModel):
+    # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.ernie = ErnieModel(config)
+        # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        task_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.ernie(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            task_type_ids=task_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        # Classify from the pooled [CLS] representation.
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            # Infer the problem type once from num_labels and the label dtype, then cache it
+            # on the config so the choice stays stable across calls.
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+ """
+ Ernie Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ERNIE_START_DOCSTRING,
+)
+class ErnieForMultipleChoice(ErniePreTrainedModel):
+    # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.__init__ with Bert->Ernie,bert->ernie
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.ernie = ErnieModel(config)
+        # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.dropout = nn.Dropout(classifier_dropout)
+        # Single score per choice; logits are reshaped to (batch, num_choices) in forward.
+        self.classifier = nn.Linear(config.hidden_size, 1)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+ @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ task_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.ernie(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ task_type_ids=task_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
@add_start_docstrings(
    """
    Ernie Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ERNIE_START_DOCSTRING,
)
class ErnieForTokenClassification(ErniePreTrainedModel):
    # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->Ernie,bert->ernie
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Token-level head does not need the pooled [CLS] output, so pooling is disabled.
        self.ernie = ErnieModel(config, add_pooling_layer=False)
        # Prefer a classifier-specific dropout if configured; otherwise reuse the generic hidden dropout.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        task_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First element of the encoder output holds the per-token hidden states.
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Flatten tokens across the batch so the cross-entropy sees (N, num_labels) vs (N,).
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Legacy tuple output: prepend the loss only when it was computed.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    Ernie Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ERNIE_START_DOCSTRING,
)
class ErnieForQuestionAnswering(ErniePreTrainedModel):
    # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->Ernie,bert->ernie
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Span prediction only needs per-token hidden states, so pooling is disabled.
        self.ernie = ErnieModel(config, add_pooling_layer=False)
        # Projects each token's hidden state to num_labels scores; the split in `forward`
        # assumes exactly two (span start / span end).
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        task_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.ernie(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            task_type_ids=task_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # First element of the encoder output holds the per-token hidden states.
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        # Split the last dimension into one start-score and one end-score per token.
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            # Out-of-range targets are clamped to `ignored_index`, which the loss then skips.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            # Average so the start and end predictions contribute equally.
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # Legacy tuple output: prepend the loss only when it was computed.
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
# Public symbols exported by this module (consumed by transformers' lazy-import machinery).
__all__ = [
    "ErnieForCausalLM",
    "ErnieForMaskedLM",
    "ErnieForMultipleChoice",
    "ErnieForNextSentencePrediction",
    "ErnieForPreTraining",
    "ErnieForQuestionAnswering",
    "ErnieForSequenceClassification",
    "ErnieForTokenClassification",
    "ErnieModel",
    "ErniePreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/esm/__init__.py b/docs/transformers/build/lib/transformers/models/esm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8eac54d6ddcbdae2b8ca3771ae5540522f6f29da
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers, expose the submodule symbols eagerly so they resolve.
    from .configuration_esm import *
    from .modeling_esm import *
    from .modeling_esmfold import *
    from .modeling_tf_esm import *
    from .tokenization_esm import *
else:
    import sys

    # At runtime, replace this package module with a `_LazyModule` proxy so submodules
    # are only imported on first attribute access (deferring their import cost).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py b/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac56bc8d783a5895496520fc62c7bada7fa5f61a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py
@@ -0,0 +1,365 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ESM model configuration"""
+
+from dataclasses import asdict, dataclass
+from typing import Optional
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# TODO Update this
+
+
class EsmConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model
    according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ESM
    [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*):
            Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ESMModel`].
        mask_token_id (`int`, *optional*):
            The index of the mask token in the vocabulary. This must be included in the config because of the
            "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
        pad_token_id (`int`, *optional*):
            The index of the padding token in the vocabulary. This must be included in the config because certain parts
            of the ESM code use this instead of the attention mask.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 1026):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`.
            For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        is_decoder (`bool`, *optional*, defaults to `False`):
            Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        emb_layer_norm_before (`bool`, *optional*):
            Whether to apply layer normalization after embeddings but before the main stem of the network.
        token_dropout (`bool`, defaults to `False`):
            When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.

    Examples:

    ```python
    >>> from transformers import EsmModel, EsmConfig

    >>> # Initializing a ESM facebook/esm-1b style configuration
    >>> configuration = EsmConfig(vocab_size=33)

    >>> # Initializing a model from the configuration
    >>> model = EsmModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "esm"

    def __init__(
        self,
        vocab_size=None,
        mask_token_id=None,
        pad_token_id=None,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="absolute",
        use_cache=True,
        emb_layer_norm_before=None,
        token_dropout=False,
        is_folding_model=False,
        esmfold_config=None,
        vocab_list=None,
        **kwargs,
    ):
        # mask_token_id is forwarded explicitly because the "mask-dropout" scaling trick needs it.
        super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.emb_layer_norm_before = emb_layer_norm_before
        self.token_dropout = token_dropout
        self.is_folding_model = is_folding_model
        # Folding (ESMFold) checkpoints carry extra config; plain ESM models leave it unset.
        if is_folding_model:
            if esmfold_config is None:
                logger.info("No esmfold_config supplied for folding model, using default values.")
                esmfold_config = EsmFoldConfig()
            elif isinstance(esmfold_config, dict):
                # Deserialized configs arrive as plain dicts; coerce back into the dataclass.
                esmfold_config = EsmFoldConfig(**esmfold_config)
            self.esmfold_config = esmfold_config
            if vocab_list is None:
                logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
                self.vocab_list = get_default_vocab_list()
            else:
                self.vocab_list = vocab_list
        else:
            self.esmfold_config = None
            self.vocab_list = None
        # The HF port rejects this ESMFold option up front rather than failing later.
        if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
            raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = super().to_dict()
        # The nested EsmFoldConfig dataclass must be expanded into a plain dict for serialization.
        if isinstance(self.esmfold_config, EsmFoldConfig):
            output["esmfold_config"] = self.esmfold_config.to_dict()
        return output
+
+
@dataclass
class EsmFoldConfig:
    """Configuration for the ESMFold-specific parts of an ESM checkpoint (see `EsmConfig.esmfold_config`)."""

    esm_type: Optional[str] = None
    fp16_esm: bool = True
    use_esm_attn_map: bool = False
    esm_ablate_pairwise: bool = False
    esm_ablate_sequence: bool = False
    esm_input_dropout: float = 0

    embed_aa: bool = True
    bypass_lm: bool = False

    lddt_head_hid_dim: int = 128
    # Nested trunk configuration; filled with defaults (or coerced from a dict) in __post_init__.
    trunk: Optional["TrunkConfig"] = None

    def __post_init__(self):
        # Accept either nothing (use defaults) or a deserialized dict for `trunk`.
        if self.trunk is None:
            self.trunk = TrunkConfig()
        elif isinstance(self.trunk, dict):
            self.trunk = TrunkConfig(**self.trunk)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = asdict(self)
        # asdict() does not know about TrunkConfig's own serialization; expand it explicitly.
        output["trunk"] = self.trunk.to_dict()
        return output
+
+
@dataclass
class TrunkConfig:
    """Configuration for the ESMFold folding trunk; validates its dimensions in `__post_init__`."""

    num_blocks: int = 48
    sequence_state_dim: int = 1024
    pairwise_state_dim: int = 128
    sequence_head_width: int = 32
    pairwise_head_width: int = 32
    position_bins: int = 32
    dropout: float = 0
    layer_drop: float = 0
    cpu_grad_checkpoint: bool = False
    max_recycles: int = 4
    chunk_size: Optional[int] = 128
    # Nested structure-module configuration; filled with defaults (or coerced from a dict) below.
    structure_module: Optional["StructureModuleConfig"] = None

    def __post_init__(self):
        """Coerce `structure_module` into its dataclass and validate the trunk dimensions.

        Raises:
            ValueError: if any dimension/ratio constraint is violated.
        """
        # Accept either nothing (use defaults) or a deserialized dict for `structure_module`.
        if self.structure_module is None:
            self.structure_module = StructureModuleConfig()
        elif isinstance(self.structure_module, dict):
            self.structure_module = StructureModuleConfig(**self.structure_module)

        if self.max_recycles <= 0:
            raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.")
        # Bug fix: the original compared each state dim against *itself* (always divisible),
        # so these checks were dead code. Each state dim must be a multiple of its head width.
        if self.sequence_state_dim % self.sequence_head_width != 0:
            raise ValueError(
                "`sequence_state_dim` should be a round multiple of `sequence_head_width`, got"
                f" {self.sequence_state_dim} and {self.sequence_head_width}."
            )
        if self.pairwise_state_dim % self.pairwise_head_width != 0:
            raise ValueError(
                "`pairwise_state_dim` should be a round multiple of `pairwise_head_width`, got"
                f" {self.pairwise_state_dim} and {self.pairwise_head_width}."
            )

        sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
        pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width

        # Kept for message parity with the original, although the divisibility checks above
        # already guarantee these equalities hold.
        if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
            raise ValueError(
                "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
                f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
            )
        if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
            raise ValueError(
                "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
                f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
            )
        if self.pairwise_state_dim % 2 != 0:
            raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.")

        if self.dropout >= 0.4:
            raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.")

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = asdict(self)
        # asdict() does not know about StructureModuleConfig's own serialization; expand it explicitly.
        output["structure_module"] = self.structure_module.to_dict()
        return output
+
+
@dataclass
class StructureModuleConfig:
    """
    Args:
        sequence_dim:
            Single representation channel dimension
        pairwise_dim:
            Pair representation channel dimension
        ipa_dim:
            IPA hidden channel dimension
        resnet_dim:
            Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
        num_heads_ipa:
            Number of IPA heads
        num_qk_points:
            Number of query/key points to generate during IPA
        num_v_points:
            Number of value points to generate during IPA
        dropout_rate:
            Dropout rate used throughout the layer
        num_blocks:
            Number of structure module blocks
        num_transition_layers:
            Number of layers in the single representation transition (Alg. 23 lines 8-9)
        num_resnet_blocks:
            Number of blocks in the angle resnet
        num_angles:
            Number of angles to generate in the angle resnet
        trans_scale_factor:
            Scale of single representation transition hidden dimension
        epsilon:
            Small number used in angle resnet normalization
        inf:
            Large number used for attention masking
    """

    sequence_dim: int = 384
    pairwise_dim: int = 128
    ipa_dim: int = 16
    resnet_dim: int = 128
    num_heads_ipa: int = 12
    num_qk_points: int = 4
    num_v_points: int = 8
    dropout_rate: float = 0.1
    num_blocks: int = 8
    num_transition_layers: int = 1
    num_resnet_blocks: int = 2
    num_angles: int = 7
    trans_scale_factor: int = 10
    epsilon: float = 1e-8
    inf: float = 1e5

    def to_dict(self):
        """Serialize this instance to a plain dictionary (no nested dataclasses here)."""
        return asdict(self)
+
+
def get_default_vocab_list():
    """Return the default ESM-2 alphabet as a 33-entry tuple of token strings.

    Layout: 4 leading special tokens, 25 residue/extra characters, then gap/terminal
    characters and 2 trailing special tokens.

    NOTE: the special tokens (`<cls>`, `<pad>`, ...) had been reduced to empty strings
    by markup stripping in this copy; restored from the upstream ESM alphabet.
    """
    return (
        "<cls>",
        "<pad>",
        "<eos>",
        "<unk>",
        "L",
        "A",
        "G",
        "V",
        "S",
        "E",
        "R",
        "T",
        "I",
        "D",
        "P",
        "K",
        "Q",
        "N",
        "F",
        "Y",
        "M",
        "H",
        "W",
        "C",
        "X",
        "B",
        "U",
        "Z",
        "O",
        ".",
        "-",
        "<null_1>",
        "<mask>",
    )
+
+
+__all__ = ["EsmConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/esm/convert_esm.py b/docs/transformers/build/lib/transformers/models/esm/convert_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..020dd4e576639230565355d82e74ad6313f875b7
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/convert_esm.py
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert ESM checkpoint."""
+
+import argparse
+import pathlib
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import esm as esm_module
+import torch
+from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences
+from esm.esmfold.v1.pretrained import esmfold_v1
+
+from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig
+from transformers.models.esm.modeling_esm import (
+ EsmForMaskedLM,
+ EsmForSequenceClassification,
+ EsmIntermediate,
+ EsmLayer,
+ EsmOutput,
+ EsmSelfAttention,
+ EsmSelfOutput,
+)
+from transformers.models.esm.modeling_esmfold import EsmForProteinFolding
+from transformers.models.esm.tokenization_esm import EsmTokenizer
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
# (name, amino-acid sequence) pairs — presumably used to sanity-check converted models
# later in this script (the consuming code is outside this view; verify before relying on it).
SAMPLE_DATA = [
    (
        "protein1",
        "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA",
    ),
    ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"),
    ("protein3", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG"),
    ("protein4", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA"),
]
+
# Maps each supported checkpoint name to the `esm` package factory that loads it.
MODEL_MAPPING = {
    "esm1b_t33_650M_UR50S": esm_module.pretrained.esm1b_t33_650M_UR50S,
    "esm1v_t33_650M_UR90S_1": esm_module.pretrained.esm1v_t33_650M_UR90S_1,
    "esm1v_t33_650M_UR90S_2": esm_module.pretrained.esm1v_t33_650M_UR90S_2,
    "esm1v_t33_650M_UR90S_3": esm_module.pretrained.esm1v_t33_650M_UR90S_3,
    "esm1v_t33_650M_UR90S_4": esm_module.pretrained.esm1v_t33_650M_UR90S_4,
    "esm1v_t33_650M_UR90S_5": esm_module.pretrained.esm1v_t33_650M_UR90S_5,
    "esm2_t48_15B_UR50D": esm_module.pretrained.esm2_t48_15B_UR50D,
    "esm2_t36_3B_UR50D": esm_module.pretrained.esm2_t36_3B_UR50D,
    "esm2_t33_650M_UR50D": esm_module.pretrained.esm2_t33_650M_UR50D,
    "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D,
    "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D,
    "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D,
    "esmfold_v1": esmfold_v1,
}
+
# The 20 standard amino-acid one-letter codes, in the canonical ESM/AlphaFold residue order.
restypes = list("ARNDCQEGHILKMFPSTWYV")

# "X" is the unknown-residue catch-all.
restypes_with_x = restypes + ["X"]
# Special tokens appended for the ESMFold tokenizer vocabulary.
# NOTE: these had been reduced to empty strings by markup stripping in this copy;
# restored from the upstream conversion script.
restypes_with_extras = restypes_with_x + ["<pad>", "<mask>", "<cls>", "<sep>", "<eos>"]
+
+
def get_esmfold_tokenizer():
    """Build an `EsmTokenizer` over the ESMFold residue alphabet using a throwaway vocab file."""
    with TemporaryDirectory() as tempdir:
        vocab_path = Path(tempdir) / "vocab.txt"
        vocab_path.write_text("\n".join(restypes_with_extras))
        hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_path))
    hf_tokenizer.pad_token_id = 0  # Overlaps with 'A' but that seems to be what they want
    return hf_tokenizer
+
+
def transfer_and_check_weights(original_module, our_module):
    """Copy `original_module`'s state dict into `our_module`, raising on any key mismatch."""
    load_result = our_module.load_state_dict(original_module.state_dict())
    for label, keys in (("Missing", load_result.missing_keys), ("Unexpected", load_result.unexpected_keys)):
        if keys:
            raise ValueError(f"{label} keys: {keys}")
+
+
+def convert_esm_checkpoint_to_pytorch(
+ model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str
+):
+ """
+ Copy/paste/tweak esm's weights to our BERT structure.
+ """
+ if model.startswith("esmfold"):
+ esm = MODEL_MAPPING[model]()
+ else:
+ esm, alphabet = MODEL_MAPPING[model]()
+ esm.eval() # disable dropout
+
+ if model.startswith("esmfold"):
+ embed_dim = esm.esm.embed_dim
+ num_layers = esm.esm.num_layers
+ num_attention_heads = esm.esm.attention_heads
+ intermediate_size = 4 * embed_dim
+ token_dropout = esm.esm.token_dropout
+ emb_layer_norm_before = False # This code path does not exist in ESM-2
+ position_embedding_type = "rotary"
+ is_folding_model = True
+ esmfold_config = EsmFoldConfig()
+ for key, val in esm.cfg.items():
+ if hasattr(esmfold_config, key) and key != "trunk":
+ setattr(esmfold_config, key, val)
+ for key, val in esm.cfg.trunk.items():
+ if hasattr(esmfold_config.trunk, key) and key != "structure_module":
+ setattr(esmfold_config.trunk, key, val)
+ for key, val in esm.cfg.trunk.structure_module.items():
+ if hasattr(esmfold_config.trunk.structure_module, key):
+ setattr(esmfold_config.trunk.structure_module, key, val)
+ elif hasattr(esm, "args"):
+ # Indicates an ESM-1b or ESM-1v model
+ embed_dim = esm.args.embed_dim
+ num_layers = esm.args.layers
+ num_attention_heads = esm.args.attention_heads
+ intermediate_size = esm.args.ffn_embed_dim
+ token_dropout = esm.args.token_dropout
+ emb_layer_norm_before = True if esm.emb_layer_norm_before else False
+ position_embedding_type = "absolute"
+ is_folding_model = False
+ esmfold_config = None
+ else:
+ # Indicates an ESM-2 model
+ embed_dim = esm.embed_dim
+ num_layers = esm.num_layers
+ num_attention_heads = esm.attention_heads
+ intermediate_size = 4 * embed_dim # This is hardcoded in ESM-2
+ token_dropout = esm.token_dropout
+ emb_layer_norm_before = False # This code path does not exist in ESM-2
+ position_embedding_type = "rotary"
+ is_folding_model = False
+ esmfold_config = None
+
+ if is_folding_model:
+ alphabet = esm.esm.alphabet
+ vocab_list = tuple(alphabet.all_toks)
+ mask_token_id = alphabet.mask_idx
+ pad_token_id = alphabet.padding_idx
+
+ if is_folding_model:
+ original_esm_model = esm.esm
+ else:
+ original_esm_model = esm
+
+ config = EsmConfig(
+ vocab_size=original_esm_model.embed_tokens.num_embeddings,
+ mask_token_id=mask_token_id,
+ hidden_size=embed_dim,
+ num_hidden_layers=num_layers,
+ num_attention_heads=num_attention_heads,
+ intermediate_size=intermediate_size,
+ max_position_embeddings=1026,
+ layer_norm_eps=1e-5, # PyTorch default used in fairseq
+ attention_probs_dropout_prob=0.0,
+ hidden_dropout_prob=0.0,
+ pad_token_id=pad_token_id,
+ emb_layer_norm_before=emb_layer_norm_before,
+ token_dropout=token_dropout,
+ position_embedding_type=position_embedding_type,
+ is_folding_model=is_folding_model,
+ esmfold_config=esmfold_config,
+ vocab_list=vocab_list,
+ )
+ if classification_head:
+ config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0]
+ print("Our ESM config:", config)
+
+ if model.startswith("esmfold"):
+ model_class = EsmForProteinFolding
+ elif classification_head:
+ model_class = EsmForSequenceClassification
+ else:
+ model_class = EsmForMaskedLM
+ model = model_class(config)
+ model.eval()
+
+ # Now let's copy all the weights.
+ # Embeddings
+ model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight
+ if position_embedding_type == "absolute":
+ model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight
+
+ if config.emb_layer_norm_before:
+ model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight
+ model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias
+
+ model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight
+ model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias
+
+ for i in range(config.num_hidden_layers):
+ # Encoder: start of layer
+ layer: EsmLayer = model.esm.encoder.layer[i]
+ # esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i]
+ esm_layer = original_esm_model.layers[i]
+
+ # self attention
+ self_attn: EsmSelfAttention = layer.attention.self
+ assert (
+ esm_layer.self_attn.k_proj.weight.data.shape
+ == esm_layer.self_attn.q_proj.weight.data.shape
+ == esm_layer.self_attn.v_proj.weight.data.shape
+ == torch.Size((config.hidden_size, config.hidden_size))
+ )
+
+ self_attn.query.weight.data = esm_layer.self_attn.q_proj.weight
+ self_attn.query.bias.data = esm_layer.self_attn.q_proj.bias
+ self_attn.key.weight.data = esm_layer.self_attn.k_proj.weight
+ self_attn.key.bias.data = esm_layer.self_attn.k_proj.bias
+ self_attn.value.weight.data = esm_layer.self_attn.v_proj.weight
+ self_attn.value.bias.data = esm_layer.self_attn.v_proj.bias
+
+ if getattr(esm_layer.self_attn, "rot_emb", None) is not None:
+ # Matt: Although inv_freq is not a trainable weight, it is computed at model init and cached.
+ # During the training of ESM-2 the model was converted to float16 precision, which also converts
+ # the inv_freq tensor, and the loss of precision remains even if the model is loaded later as float32.
+ # If we recompute inv_freq without this loss of precision then we will get subtly different rotary
+ # embeddings, which are enough to cause significant discrepancies in model outputs. To avoid this,
+ # we make sure the new model copies the data from the old inv_freq.
+ self_attn.rotary_embeddings.inv_freq.data = esm_layer.self_attn.rot_emb.inv_freq
+
+ # LayerNorm changes for pre-activation
+ layer.attention.LayerNorm.weight = esm_layer.self_attn_layer_norm.weight
+ layer.attention.LayerNorm.bias = esm_layer.self_attn_layer_norm.bias
+ layer.LayerNorm.weight = esm_layer.final_layer_norm.weight
+ layer.LayerNorm.bias = esm_layer.final_layer_norm.bias
+
+ # self-attention output
+ self_output: EsmSelfOutput = layer.attention.output
+ assert self_output.dense.weight.shape == esm_layer.self_attn.out_proj.weight.shape
+ self_output.dense.weight = esm_layer.self_attn.out_proj.weight
+ self_output.dense.bias = esm_layer.self_attn.out_proj.bias
+
+ # intermediate
+ intermediate: EsmIntermediate = layer.intermediate
+ assert intermediate.dense.weight.shape == esm_layer.fc1.weight.shape
+ intermediate.dense.weight = esm_layer.fc1.weight
+ intermediate.dense.bias = esm_layer.fc1.bias
+
+ # output
+ bert_output: EsmOutput = layer.output
+ assert bert_output.dense.weight.shape == esm_layer.fc2.weight.shape
+ bert_output.dense.weight = esm_layer.fc2.weight
+ bert_output.dense.bias = esm_layer.fc2.bias
+ # end of layer
+
+ if is_folding_model:
+ model.esm_s_combine.data = esm.esm_s_combine.data
+ model.af2_to_esm.data = esm.af2_to_esm.data
+ transfer_and_check_weights(esm.embedding, model.embedding)
+ transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp)
+ transfer_and_check_weights(esm.trunk, model.trunk)
+ transfer_and_check_weights(esm.distogram_head, model.distogram_head)
+ transfer_and_check_weights(esm.ptm_head, model.ptm_head)
+ transfer_and_check_weights(esm.lm_head, model.lm_head)
+ transfer_and_check_weights(esm.lddt_head, model.lddt_head)
+
+ elif classification_head:
+ model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight
+ model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias
+ model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight
+ model.classifier.out_proj.bias = esm.classification_heads["mnli"].out_proj.bias
+ else:
+ # LM Head
+ model.lm_head.dense.weight = esm.lm_head.dense.weight
+ model.lm_head.dense.bias = esm.lm_head.dense.bias
+ model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight
+ model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias
+ model.lm_head.decoder.weight = esm.lm_head.weight
+ model.lm_head.bias = esm.lm_head.bias
+
+ # Contact prediction head
+ transfer_and_check_weights(esm.contact_head, model.esm.contact_head)
+
+ # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
+ if is_folding_model:
+ # Folding models aren't trained on masked inputs and don't like mask tokens.
+ sample_data = SAMPLE_DATA[:2]
+ else:
+ sample_data = SAMPLE_DATA
+
+ if is_folding_model:
+ hf_tokenizer = get_esmfold_tokenizer()
+ hf_tokens = hf_tokenizer(
+ [row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False
+ )
+ esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data])
+ success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all(
+ hf_tokens["attention_mask"] == esmfold_mask
+ )
+ else:
+ # Let's check that we get the same results.
+ batch_converter = alphabet.get_batch_converter()
+ batch_labels, batch_strs, batch_tokens = batch_converter(sample_data)
+ # Prepare tokenizer and make sure it matches
+ with TemporaryDirectory() as tempdir:
+ vocab = "\n".join(alphabet.all_toks)
+ vocab_file = Path(tempdir) / "vocab.txt"
+ vocab_file.write_text(vocab)
+ hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file))
+
+ hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True)
+ success = torch.all(hf_tokens["input_ids"] == batch_tokens)
+
+ print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩")
+ if not success:
+ raise Exception("Tokenization does not match!")
+
+ with torch.no_grad():
+ if is_folding_model:
+ # Let's test the model in parts
+ # ESMFold always converts the ESM stem to float16, which requires float16 ops
+ # that don't exist on CPU. Therefore, to test it we need to run it on GPU. However,
+ # ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the
+ # original and the converted model on the GPU at the same time.
+ their_output = esm.cuda().infer([row[1] for row in sample_data])
+ our_output = model.cuda()(
+ input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda()
+ )
+ else:
+ our_output = model(**hf_tokens, output_hidden_states=True)
+ our_output = our_output["logits"]
+ if classification_head:
+ their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens))
+ else:
+ their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999)))
+ their_output = their_output["logits"]
+
+ if is_folding_model:
+ max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item()
+ success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5)
+ else:
+ max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
+ success = torch.allclose(our_output, their_output, atol=1e-5)
+
+ print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
+ print("Do both models output the same tensors?", "🔥" if success else "💩")
+
+ if not success:
+ raise Exception("Something went wRoNg")
+
+ if not is_folding_model:
+ # Let's check contact prediction too
+ our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"])
+ their_output = esm.predict_contacts(hf_tokens["input_ids"])
+ max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
+ success = torch.allclose(our_output, their_output, atol=1e-5)
+
+ print("Contact prediction testing:")
+ print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5
+ print("Do both models output the same tensors?", "🔥" if success else "💩")
+
+ if not success:
+ raise Exception("Something went wRoNg")
+
+ pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
+ print(f"Saving model to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ del esm # Free up some memory before continuing
+
+ print(f"Saving tokenizer to {pytorch_dump_folder_path}")
+ hf_tokenizer.save_pretrained(pytorch_dump_folder_path)
+
+ if push_to_repo:
+ model.push_to_hub(repo_id=push_to_repo, token_token=auth_token)
+ hf_tokenizer.push_to_hub(repo_id=push_to_repo, token_token=auth_token)
+
+
if __name__ == "__main__":
    # Command-line entry point for the checkpoint conversion.
    arg_parser = argparse.ArgumentParser()
    # Required parameters
    arg_parser.add_argument(
        "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model."
    )
    arg_parser.add_argument(
        "--classification_head", action="store_true", help="Whether to convert a final classification head."
    )
    arg_parser.add_argument("--model", default=None, type=str, required=True, help="Name of model to convert.")
    arg_parser.add_argument("--push_to_repo", type=str, help="Repo to upload to (including username!).")
    arg_parser.add_argument("--auth_token", type=str, help="HuggingFace auth token.")
    parsed = arg_parser.parse_args()
    convert_esm_checkpoint_to_pytorch(
        parsed.model,
        parsed.pytorch_dump_folder_path,
        parsed.classification_head,
        parsed.push_to_repo,
        parsed.auth_token,
    )
diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py b/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f90d8d052f9bc9a3f0c5cdeb74aad823b44bf21
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py
@@ -0,0 +1,1273 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ESM model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ MaskedLMOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import logging
+from .configuration_esm import EsmConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
+_CONFIG_FOR_DOC = "EsmConfig"
+
+
def rotate_half(x):
    """Rotate the last dimension by half: split x into (a, b) halves and return (-b, a)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
+
+
def apply_rotary_pos_emb(x, cos, sin):
    """Apply rotary position embeddings to x, truncating the cos/sin tables to x's sequence length."""
    seq_len = x.shape[-2]
    cos_table = cos[:, :, :seq_len, :]
    sin_table = sin[:, :, :seq_len, :]
    return (x * cos_table) + (rotate_half(x) * sin_table)
+
+
def gelu(x):
    """
    Exact-erf GELU from the original ESM repo. Using F.gelu instead yields subtly wrong results.
    """
    erf_term = torch.erf(x / math.sqrt(2.0))
    return 0.5 * x * (1.0 + erf_term)
+
+
def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    transposed = x.transpose(-1, -2)
    return x + transposed
+
+
def average_product_correct(x):
    "Perform average product correct, used for contact prediction."
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    # Expected value under independence of rows/columns; divided in place
    # to avoid allocating an extra full-size tensor.
    correction = row_sums * col_sums
    correction.div_(total)
    return x - correction
+
+
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable).
        # (The original code had a redundant no-op `inv_freq = inv_freq` here; removed.)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Cos/sin tables are cached and recomputed only when the incoming
        # sequence length (or device) changes.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """Return cached cos/sin tables, rebuilding them if x's sequence length or device changed."""
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance).
        # On the first call `_seq_len_cached` is None, so the left comparison is
        # True and `_cos_cached.device` is never dereferenced while still None.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate the frequencies so the table matches the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            # Shaped (1, 1, seq_len, dim) for broadcasting over batch and heads.
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate query and key tensors; the tables are sized from k's sequence dimension."""
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
+
+
class EsmContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
    ):
        super().__init__()
        self.in_features = in_features
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # Zero out attention to/from EOS positions, then drop the trailing EOS row/column.
        eos_mask = tokens.ne(self.eos_idx).to(attentions)
        pairwise_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
        attentions = attentions * pairwise_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # Drop the leading CLS row/column as well.
        attentions = attentions[..., 1:, 1:]
        # Flatten (layers, heads) into a single channel axis for the regression.
        batch_size, num_layers, num_heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, num_layers * num_heads, seqlen, seqlen)

        # Attentions are always float32 and may live on a different device than
        # the regression weights, so move the features over first.
        attentions = attentions.to(self.regression.weight.device)
        features = average_product_correct(symmetrize(attentions))
        features = features.permute(0, 2, 3, 1)
        return self.activation(self.regression(features).squeeze(3))
+
+
class EsmEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config):
        super().__init__()
        # Token embedding table; the padding row stays zero and receives no gradient.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Optional LayerNorm applied to the embeddings before the encoder
        # (only when config.emb_layer_norm_before is set).
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        # Learned absolute position embeddings; only added when
        # position_embedding_type == "absolute" (see forward()).
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
        )
        # ESM-specific "mask-dropout" handling of <mask> tokens, see forward().
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def forward(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Embed token ids (or pre-computed embeddings) and add position information.

        NOTE(review): when token_dropout is enabled, `input_ids` and `attention_mask`
        must both be provided — the scaling below dereferences them unconditionally.
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        if self.token_dropout:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
            src_lengths = attention_mask.sum(-1)
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        # Zero out embeddings at padding positions.
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        # Positions start at padding_idx + 1 so they never collide with the padding index.
        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
+
+
class EsmSelfAttention(nn.Module):
    """Multi-head self/cross-attention with ESM's query scaling and optional rotary,
    relative-key, or absolute position handling."""

    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # An explicit argument overrides the config setting (used for cross-attention variants).
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape (batch, seq, all_head_size) -> (batch, heads, seq, head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Compute (cross-)attention; returns (context, [attn_probs], [present_key_value])."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Decoder self-attention with cache: append new keys/values along the sequence dim.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # ESM code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Rotary embeddings are applied to the (already scaled) queries and keys.
        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances into [0, 2 * max_position_embeddings - 2] for the embedding lookup.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer)

        # Reshape (batch, heads, seq, head_size) back to (batch, seq, all_head_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
+
+
class EsmSelfOutput(nn.Module):
    """Output projection of the attention block with dropout and a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # Project, regularize, then add the residual. No LayerNorm here:
        # ESM normalizes before attention (see EsmAttention).
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
+
+
class EsmAttention(nn.Module):
    """Pre-activation attention block: LayerNorm, then self-attention, then residual output."""

    def __init__(self, config):
        super().__init__()
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.pruned_heads = set()
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink all projections accordingly."""
        if not heads:
            return
        prunable_heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Shrink the q/k/v projections (output rows) and the output projection (input columns).
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping consistent with the smaller projections.
        self.self.num_attention_heads = self.self.num_attention_heads - len(prunable_heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(prunable_heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # ESM normalizes *before* attention; the residual uses the un-normalized input.
        normed_states = self.LayerNorm(hidden_states)
        attn_results = self.self(
            normed_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(attn_results[0], hidden_states)
        # Forward any extra outputs (attention probs, kv cache) unchanged.
        return (attention_output,) + attn_results[1:]
+
+
class EsmIntermediate(nn.Module):
    """Feed-forward expansion: project to intermediate size and apply ESM's exact-erf GELU."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Uses the module-level gelu (erf form) rather than F.gelu, matching the original ESM repo.
        return gelu(self.dense(hidden_states))
+
+
class EsmOutput(nn.Module):
    """Feed-forward contraction: project back to hidden size, dropout, and add the residual."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
+
+
class EsmLayer(nn.Module):
    """One transformer block: pre-LN self-attention, optional cross-attention, and a feed-forward sublayer."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EsmAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention only makes sense when this layer is part of a decoder.
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run the block; returns (hidden_states, [attn weights...], [present_key_value])."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = self.feed_forward_chunk(attention_output)

        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)
        return outputs

    def feed_forward_chunk(self, attention_output):
        # Pre-activation feed-forward: LayerNorm first, residual taken from the
        # un-normalized attention output inside self.output.
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
+
+
class EsmEncoder(nn.Module):
    """Stack of EsmLayer blocks with a final LayerNorm and optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Run all layers; returns a BaseModelOutputWithPastAndCrossAttentions (or tuple if return_dict=False)."""
        # Checkpointing recomputes activations in backward, which is incompatible
        # with returning a kv cache — force use_cache off.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                    "`use_cache=False`..."
                )
                use_cache = False
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            # hidden_states collected *before* each layer (the post-final-LN state is added after the loop).
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer appends its present_key_value as its last output.
                next_decoder_cache = next_decoder_cache + (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # NOTE(review): this truthiness test is on an nn.Module, which is always
        # truthy — the branch always runs here since the LayerNorm is created in
        # __init__. An `is not None` check would express the intent more directly.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output mirrors the dataclass fields, skipping any that are None.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
+
+
# Copied from transformers.models.bert.modeling_bert.BertPooler
class EsmPooler(nn.Module):
    """Pools a sequence of hidden states into a single vector per example.

    Takes the hidden state of the first token, applies a learned dense projection
    and a tanh non-linearity.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pooling" here is simply: first token -> dense -> tanh.
        first_token_state = hidden_states[:, 0]
        return self.activation(self.dense(first_token_state))
+
+
class EsmPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EsmConfig
    base_model_prefix = "esm"
    supports_gradient_checkpointing = True
    _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"]

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead
    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
            return
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if isinstance(module, EsmLMHead):
            # The LM head carries a standalone vocab-size bias parameter.
            module.bias.data.zero_()
+
+
# Docstring fragment injected into model classes via `@add_start_docstrings`
# (see the decorated classes below). Runtime string — do not edit casually.
ESM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`EsmConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-pass argument docs; the `{0}` placeholder is filled with the input shape
# via `ESM_INPUTS_DOCSTRING.format(...)` in `@add_start_docstrings_to_model_forward`.
ESM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class EsmModel(EsmPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = EsmEmbeddings(config)
        self.encoder = EsmEncoder(config)

        # Optional [CLS]-style pooler; heads that do per-token work disable it.
        self.pooler = EsmPooler(config) if add_pooling_layer else None

        # Consumes the stacked attention maps of every layer/head, hence
        # num_hidden_layers * num_attention_heads input features (see `predict_contacts`).
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding table (used by the framework for resizing/tying)."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Caching is only meaningful in decoder mode; force it off otherwise.
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length: length of the already-cached prefix (dim 2 of the cached keys).
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            # Default: attend to every position, including the cached prefix.
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        # pooled_output is None when the model was built with add_pooling_layer=False.
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Predict residue-residue contacts from this model's self-attention maps.

        Runs a forward pass with `output_attentions=True`, stacks the per-layer
        attention tensors, masks out padding positions, and feeds the result to
        `self.contact_head`.
        """
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = torch.stack(attns, dim=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
        return self.contact_head(tokens, attns)
+
+
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class EsmForMaskedLM(EsmPreTrainedModel):
    # `lm_head.decoder.weight` participates in output-embedding tying.
    _tied_weights_keys = ["lm_head.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # MLM predicts per token, so the [CLS] pooling layer is not needed.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.lm_head = EsmLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the vocabulary projection (used by the framework for tying/resizing)."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the vocabulary projection."""
        self.lm_head.decoder = new_embeddings

    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        # Fix: the generated usage example must show the tokenizer's mask placeholder.
        # This was `mask=""` (the "<mask>" token appears to have been stripped as an
        # HTML-like tag), which produced a broken code sample prompt.
        mask="<mask>",
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            # Move labels to the logits device (relevant when the head lives on a
            # different device, e.g. under model parallelism — see the .to() pattern
            # in the other heads in this file).
            labels = labels.to(prediction_scores.device)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # outputs[2:] skips the backbone's (sequence_output, pooled_output) pair;
            # pooled_output is always None here since pooling is disabled.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate contact prediction to the backbone (see `EsmModel.predict_contacts`)."""
        return self.esm.predict_contacts(tokens, attention_mask=attention_mask)
+
+
class EsmLMHead(nn.Module):
    """ESM Head for masked language modeling.

    Maps hidden states to vocabulary logits: dense -> GELU -> LayerNorm -> bias-free
    decoder projection plus a standalone bias parameter.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # The decoder has no built-in bias; a separate parameter is added in `forward`
        # (presumably so weight tying with the input embeddings leaves the bias untied
        # — confirm against `_tied_weights_keys` on EsmForMaskedLM).
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, features, **kwargs):
        hidden = self.layer_norm(gelu(self.dense(features)))
        # Project back to vocabulary size and add the standalone bias.
        return self.decoder(hidden) + self.bias
+
+
@add_start_docstrings(
    """
    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ESM_START_DOCSTRING,
)
class EsmForSequenceClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Backbone without pooler; classification uses the first token via EsmClassificationHead.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.classifier = EsmClassificationHead(config)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Move labels to the logits device (presumably for model-parallel setups).
            labels = labels.to(logits.device)

            # Infer the loss type once from num_labels / label dtype and cache it on
            # the config so subsequent calls reuse the same decision.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # outputs[2:] skips the backbone's (sequence_output, pooled_output) pair.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+
+
@add_start_docstrings(
    """
    ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ESM_START_DOCSTRING,
)
class EsmForTokenClassification(EsmPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token classification: ESM backbone (no pooler) + dropout + linear head.
        self.esm = EsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_states = self.dropout(backbone_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Flatten (batch, seq, labels) -> (batch*seq, labels) for cross entropy;
            # labels go to the logits device first.
            labels = labels.to(logits.device)
            loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            # Tuple path: skip the backbone's (sequence_output, pooled_output) pair.
            output = (logits,) + backbone_outputs[2:]
            if loss is None:
                return output
            return (loss,) + output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
+
+
class EsmClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        # Classify from the first token's representation (equiv. to [CLS]), with
        # dropout applied both before the dense projection and before the output layer.
        cls_state = self.dropout(features[:, 0, :])
        hidden = self.dropout(torch.tanh(self.dense(cls_state)))
        return self.out_proj(hidden)
+
+
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`): Token ids; entries equal to `padding_idx` keep position `padding_idx`.
        padding_idx (`int`): Id of the padding token.
        past_key_values_length (`int`, *optional*, defaults to 0): Offset added when decoding with a cached prefix.

    Returns:
        `torch.LongTensor` of the same shape as `input_ids`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    not_padding = input_ids.ne(padding_idx).int()
    positions = torch.cumsum(not_padding, dim=1).type_as(not_padding)
    # Zero out padding positions again, then shift so positions start at padding_idx + 1.
    positions = (positions + past_key_values_length) * not_padding
    return positions.long() + padding_idx
+
+
# Public names exported from this module via `from <module> import *`.
__all__ = [
    "EsmForMaskedLM",
    "EsmForSequenceClassification",
    "EsmForTokenClassification",
    "EsmModel",
    "EsmPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py b/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py
new file mode 100644
index 0000000000000000000000000000000000000000..645c9d16a5c5dcd012cf3d91ca784267ede9c654
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py
@@ -0,0 +1,2325 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import sys
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+
+from ...integrations.deepspeed import is_deepspeed_available
+from ...modeling_outputs import ModelOutput
+from ...utils import (
+ ContextManagers,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_scipy_available,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_esm import EsmConfig
+from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel
+from .openfold_utils import (
+ OFProtein,
+ Rigid,
+ Rotation,
+ atom14_to_atom37,
+ chunk_layer,
+ compute_predicted_aligned_error,
+ compute_tm,
+ frames_and_literature_positions_to_atom14_pos,
+ make_atom14_masks,
+ residue_constants,
+ to_pdb,
+ torsion_angles_to_frames,
+)
+
+
logger = logging.get_logger(__name__)  # module-level logger (transformers logging wrapper)
# NOTE(review): presumably the checkpoint/config names interpolated into generated
# docstrings by the doc decorators — confirm at the decorator use sites.
_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1"
_CONFIG_FOR_DOC = "EsmConfig"
+
+
@dataclass
class EsmForProteinFoldingOutput(ModelOutput):
    """
    Output type of [`EsmForProteinFoldingOutput`].

    Args:
        frames (`torch.FloatTensor`):
            Output frames.
        sidechain_frames (`torch.FloatTensor`):
            Output sidechain frames.
        unnormalized_angles (`torch.FloatTensor`):
            Predicted unnormalized backbone and side chain torsion angles.
        angles (`torch.FloatTensor`):
            Predicted backbone and side chain torsion angles.
        positions (`torch.FloatTensor`):
            Predicted positions of the backbone and side chain atoms.
        states (`torch.FloatTensor`):
            Hidden states from the protein folding trunk.
        s_s (`torch.FloatTensor`):
            Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem.
        s_z (`torch.FloatTensor`):
            Pairwise residue embeddings.
        distogram_logits (`torch.FloatTensor`):
            Input logits to the distogram used to compute residue distances.
        lm_logits (`torch.FloatTensor`):
            Logits output by the ESM-2 protein language model stem.
        aatype (`torch.FloatTensor`):
            Input amino acids (AlphaFold2 indices).
        atom14_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom14 representation.
        residx_atom14_to_atom37 (`torch.FloatTensor`):
            Mapping between atoms in the atom14 and atom37 representations.
        residx_atom37_to_atom14 (`torch.FloatTensor`):
            Mapping between atoms in the atom37 and atom14 representations.
        atom37_atom_exists (`torch.FloatTensor`):
            Whether each atom exists in the atom37 representation.
        residue_index (`torch.FloatTensor`):
            The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be
            a sequence of integers from 0 to `sequence_length`.
        lddt_head (`torch.FloatTensor`):
            Raw outputs from the lddt head used to compute plddt.
        plddt (`torch.FloatTensor`):
            Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is
            uncertain, or where the protein structure is disordered.
        ptm_logits (`torch.FloatTensor`):
            Raw logits used for computing ptm.
        ptm (`torch.FloatTensor`):
            TM-score output representing the model's high-level confidence in the overall structure.
        aligned_confidence_probs (`torch.FloatTensor`):
            Per-residue confidence scores for the aligned structure.
        predicted_aligned_error (`torch.FloatTensor`):
            Predicted error between the model's prediction and the ground truth.
        max_predicted_aligned_error (`torch.FloatTensor`):
            Per-sample maximum predicted error.
    """

    # Every field defaults to None, so partially populated outputs are representable.
    frames: Optional[torch.FloatTensor] = None
    sidechain_frames: Optional[torch.FloatTensor] = None
    unnormalized_angles: Optional[torch.FloatTensor] = None
    angles: Optional[torch.FloatTensor] = None
    positions: Optional[torch.FloatTensor] = None
    states: Optional[torch.FloatTensor] = None
    s_s: Optional[torch.FloatTensor] = None
    s_z: Optional[torch.FloatTensor] = None
    distogram_logits: Optional[torch.FloatTensor] = None
    lm_logits: Optional[torch.FloatTensor] = None
    aatype: Optional[torch.FloatTensor] = None
    atom14_atom_exists: Optional[torch.FloatTensor] = None
    residx_atom14_to_atom37: Optional[torch.FloatTensor] = None
    residx_atom37_to_atom14: Optional[torch.FloatTensor] = None
    atom37_atom_exists: Optional[torch.FloatTensor] = None
    residue_index: Optional[torch.FloatTensor] = None
    lddt_head: Optional[torch.FloatTensor] = None
    plddt: Optional[torch.FloatTensor] = None
    ptm_logits: Optional[torch.FloatTensor] = None
    ptm: Optional[torch.FloatTensor] = None
    aligned_confidence_probs: Optional[torch.FloatTensor] = None
    predicted_aligned_error: Optional[torch.FloatTensor] = None
    max_predicted_aligned_error: Optional[torch.FloatTensor] = None
+
+
# Shared docstring template for forward methods. NOTE(review): "{0}" is presumably
# filled with the input shape by `add_start_docstrings_to_model_forward` — confirm
# at the use site.
ESMFOLD_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*):
            Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`.
        num_recycles (`int`, *optional*, defaults to `None`):
            Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling"
            consists of passing the output of the folding trunk back in as input to the trunk. During training, the
            number of recycles should vary with each batch, to ensure that the model learns to output valid predictions
            after each recycle. During inference, num_recycles should be set to the highest value that the model was
            trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is
            used.
"""
+
+
def is_fp16_enabled():
    """Return True when autocast is active with float16 as the GPU autocast dtype."""
    return torch.is_autocast_enabled() and torch.get_autocast_gpu_dtype() == torch.float16
+
+
def is_deepspeed_initialized():
    """
    Return True if DeepSpeed is installed and its engine reports being initialized.

    Bug fix: the availability check was inverted — the function returned False
    precisely when DeepSpeed *was* available, and only attempted the import when
    it was absent (where it could only fail). If DeepSpeed is unavailable it
    cannot be initialized, so return False; otherwise query DeepSpeed directly.
    """
    if not is_deepspeed_available():
        return False
    try:
        import deepspeed

        # This is not available in all DeepSpeed versions.
        return deepspeed.utils.is_initialized()
    except Exception:
        return False
+
+
def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor:
    """
    Stack tensors of equal rank (but possibly different sizes) into one padded tensor.

    Given samples shaped [(d_11, ..., d_1K), ..., (d_N1, ..., d_NK)], returns a tensor
    of shape (N, max_i d_i1, ..., max_i d_iK) where each sample occupies the leading
    corner of its slot and the remainder is filled with `pad_v`.
    """
    if not samples:
        return torch.Tensor()
    if len({x.dim() for x in samples}) != 1:
        raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}")
    (device,) = {x.device for x in samples}  # assumes all samples live on one device
    padded_shape = [max(sizes) for sizes in zip(*(x.shape for x in samples))]
    result = torch.empty(len(samples), *padded_shape, dtype=samples[0].dtype, device=device)
    result.fill_(pad_v)
    # Copy each sample into the leading corner of its padded slot.
    for slot, sample in zip(result, samples):
        slot[tuple(slice(0, size) for size in sample.shape)] = sample
    return result
+
+
def flatten_final_dims(t: torch.Tensor, no_dims: int):
    """Collapse the last `no_dims` dimensions of `t` into a single trailing dimension."""
    return t.reshape(*t.shape[:-no_dims], -1)
+
+
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    """Permute the last `len(inds)` dimensions of `tensor` according to `inds`, keeping leading dims fixed."""
    n_kept = tensor.dim() - len(inds)
    leading = list(range(n_kept))
    trailing = [n_kept + i for i in inds]
    return tensor.permute(leading + trailing)
+
+
def dict_multimap(fn, dicts):
    """
    Zip a list of same-keyed dicts and apply `fn` to each list of collected values.

    Nested dicts are recursed into; the first dict supplies the key structure.
    """
    template = dicts[0]
    out = {}
    for key, value in template.items():
        collected = [d[key] for d in dicts]
        out[key] = dict_multimap(fn, collected) if isinstance(value, dict) else fn(collected)
    return out
+
+
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """
    Initialize `weights` in place from a truncated normal distribution.

    Variance is scaled by `scale / fan_in` (LeCun-style; `weights.shape[1]` is taken
    as fan-in — `fan` is accepted only for signature compatibility). Samples are
    truncated at +/-2 standardized deviations, matching scipy's truncnorm(a=-2, b=2).

    Args:
        weights (`torch.Tensor`): tensor to initialize in place.
        scale (`float`): variance scale factor.
        fan (`str`): unused; kept for API compatibility.
    """
    shape = weights.shape
    scale = scale / max(1, shape[1])

    if not is_scipy_available():
        logger.warning(
            "This init requires scipy, but scipy was not found, default to an approximation that might not be"
            " equivalent."
        )
        std = math.sqrt(scale)
        # Bug fix: `.clamp(...)` is out-of-place and its result was being discarded,
        # so no truncation was ever applied. Clamp in place, with symmetric bounds
        # (+-2 std) matching the scipy branch's truncnorm(a=-2, b=2) support.
        torch.nn.init.normal_(weights, std=std)
        with torch.no_grad():
            weights.clamp_(min=-2.0 * std, max=2.0 * std)

    else:
        from scipy.stats import truncnorm

        # Rescale std so the *truncated* distribution has the requested std.
        std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1)
        samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel())
        samples = np.reshape(samples, shape)
        weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
+
+
class EsmFoldLinear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear.

    Implements the initializers in 1.11.4, plus some additional ones found in the code.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        bias: bool = True,
        init: str = "default",
        init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
    ):
        """
        Args:
            in_dim:
                The final dimension of inputs to the layer
            out_dim:
                The final dimension of layer outputs
            bias:
                Whether to learn an additive bias. True by default
            init:
                Name of the initializer to apply later (by `_init_weights`). One of:
                "default" (LeCun fan-in truncated normal), "relu" (He truncated
                normal), "glorot" (fan-average Glorot uniform), "gating"
                (weights=0, bias=1), "normal" (std=1/sqrt(fan_in)), "final"
                (weights=0, bias=0). Overridden by `init_fn` when the latter is
                not None.
            init_fn:
                A custom initializer taking weight and bias as inputs. Overrides init if not None.
        """
        super().__init__(in_dim, out_dim, bias=bias)

        # The bias always starts at zero; schemes like "gating" overwrite it later.
        if bias:
            nn.init.zeros_(self.bias)
        self.init = init
        self.init_fn = init_fn

        if init not in ("default", "relu", "glorot", "gating", "normal", "final"):
            raise ValueError("Invalid init string.")
+
+
class EsmFoldLayerNorm(nn.Module):
    """LayerNorm that avoids the automatic fp32 upcast for bfloat16 inputs outside DeepSpeed."""

    def __init__(self, c_in, eps=1e-5):
        super().__init__()

        # Normalized-shape tuple expected by nn.functional.layer_norm.
        self.c_in = (c_in,)
        self.eps = eps

        self.weight = nn.Parameter(torch.ones(c_in))
        self.bias = nn.Parameter(torch.zeros(c_in))

    def forward(self, x):
        dtype = x.dtype
        # For bf16 (and no DeepSpeed engine), disable autocast and run the norm
        # with weight/bias cast to the input dtype.
        if dtype is torch.bfloat16 and not is_deepspeed_initialized():
            with torch.cuda.amp.autocast(enabled=False):
                return nn.functional.layer_norm(
                    x, self.c_in, self.weight.to(dtype=dtype), self.bias.to(dtype=dtype), self.eps
                )
        return nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps)
+
+
@torch.jit.ignore
def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Softmax, but without automatic casting to fp32 when the input is of type bfloat16
    """
    # Outside DeepSpeed, bf16 inputs get the autocast context disabled so the
    # softmax stays in bf16 rather than being silently promoted to fp32.
    if t.dtype is torch.bfloat16 and not is_deepspeed_initialized():
        with torch.cuda.amp.autocast(enabled=False):
            return torch.nn.functional.softmax(t, dim=dim)
    return torch.nn.functional.softmax(t, dim=dim)
+
+
class EsmFoldAttention(nn.Module):
    """
    Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors.
    """

    def __init__(
        self,
        c_q: int,
        c_k: int,
        c_v: int,
        c_hidden: int,
        no_heads: int,
        gating: bool = True,
    ):
        """
        Args:
            c_q:
                Input dimension of query data
            c_k:
                Input dimension of key data
            c_v:
                Input dimension of value data
            c_hidden:
                Per-head hidden dimension
            no_heads:
                Number of attention heads
            gating:
                Whether the output should be gated using query data
        """
        super().__init__()

        self.c_q = c_q
        self.c_k = c_k
        self.c_v = c_v
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.gating = gating

        # DISCREPANCY: c_hidden is not the per-head channel dimension, as
        # stated in the supplement, but the overall channel dimension.

        self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
        self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final")

        self.linear_g = None
        if self.gating:
            self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating")

        self.sigmoid = nn.Sigmoid()

    def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project inputs to per-head Q/K/V and pre-scale the queries by 1/sqrt(c_hidden)."""
        # [*, Q/K/V, H * C_hidden]
        q = self.linear_q(q_x)
        k = self.linear_k(kv_x)
        v = self.linear_v(kv_x)

        # [*, Q/K, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))
        k = k.view(k.shape[:-1] + (self.no_heads, -1))
        v = v.view(v.shape[:-1] + (self.no_heads, -1))

        # [*, H, Q/K, C_hidden]
        q = q.transpose(-2, -3)
        k = k.transpose(-2, -3)
        v = v.transpose(-2, -3)

        # Attention scaling is folded into the queries (in place) instead of the logits.
        q /= math.sqrt(self.c_hidden)

        return q, k, v

    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
        """Optionally gate the attention output with the query data, then apply the output projection."""
        if self.linear_g is not None:
            g = self.sigmoid(self.linear_g(q_x))

            # [*, Q, H, C_hidden]
            g = g.view(g.shape[:-1] + (self.no_heads, -1))
            o = o * g

        # [*, Q, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, Q, C_q]
        o = self.linear_o(o)

        return o

    def forward(
        self,
        q_x: torch.Tensor,
        kv_x: torch.Tensor,
        biases: Optional[List[torch.Tensor]] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        lma_q_chunk_size: int = 1024,
        lma_kv_chunk_size: int = 4096,
        use_flash: bool = False,
        flash_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            q_x:
                [*, Q, C_q] query data
            kv_x:
                [*, K, C_k] key data
            biases:
                List of biases that broadcast to [*, H, Q, K]
            use_memory_efficient_kernel:
                Whether to use a custom memory-efficient attention kernel. This should be the default choice for most.
                If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead
            use_lma:
                Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a
                stock PyTorch implementation is used instead
            lma_q_chunk_size:
                Query chunk size (for LMA)
            lma_kv_chunk_size:
                Key/Value chunk size (for LMA)
        Returns
            [*, Q, C_q] attention update
        """
        if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
            raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided")

        if use_flash and biases is not None:
            raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead")

        attn_options = [use_memory_efficient_kernel, use_lma, use_flash]
        if sum(attn_options) > 1:
            raise ValueError("Choose at most one alternative attention algorithm")

        if biases is None:
            biases = []

        # NOTE: the "use_<...>" flags are only validated above; this port always
        # runs the stock PyTorch attention implementation below.
        # [*, H, Q/K, C_hidden]
        query, key, value = self._prep_qkv(q_x, kv_x)
        key = permute_final_dims(key, (1, 0))

        # [*, H, Q, K]
        output = torch.matmul(query, key)
        for b in biases:
            output += b
        output = softmax_no_cast(output, -1)

        # [*, H, Q, C_hidden]
        output = torch.matmul(output, value)
        output = output.transpose(-2, -3)
        output = self._wrap_up(output, q_x)

        return output
+
+
class EsmFoldTriangleAttention(nn.Module):
    """Triangle attention over the pair representation (rows when `starting`, columns otherwise)."""

    def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If False, the pair tensor (and mask) are transposed before and
                after attention so the other pair axis is attended over.
            inf:
                Large value used to build the additive mask bias
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Projects the pair representation to a per-head scalar "triangle" bias.
        self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads)

    @torch.jit.ignore
    def _chunk(
        self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        "triangle! triangle!"
        # Run the MHA over batch-like leading dims in chunks to bound peak memory;
        # when inplace_safe, results are written back into x.
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }

        return chunk_layer(
            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=x if inplace_safe else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask (1 = keep, 0 = mask out); defaults to all ones
            chunk_size:
                If not None, run the attention through `_chunk` with this chunk size
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(
                x.shape[:-1],
            )

        if not self.starting:
            # Attend over the other pair axis by transposing in and out.
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J] — masked positions get a large negative additive bias.
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )
        else:
            x = self.mha(
                q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma
            )

        if not self.starting:
            x = x.transpose(-2, -3)

        return x
+
+
class EsmFoldTriangleMultiplicativeUpdate(nn.Module):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, config, _outgoing=True):
        # _outgoing selects which of the two algorithms runs; it only changes how
        # the a/b projections are permuted in `_combine_projections`.
        super().__init__()
        c_hidden = config.pairwise_state_dim
        self._outgoing = _outgoing

        self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden)
        self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating")
        self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final")

        self.layer_norm_in = LayerNorm(c_hidden)
        self.layer_norm_out = LayerNorm(c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(
        self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """Matmul the two projections with channels moved to a batch dim, then move them back."""
        if self._outgoing:
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i : i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i : i + _inplace_chunk_size, :, :]
                a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul(
                    a_chunk,
                    b_chunk,
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    def _inference_forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to trade memory for speed.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update).
        Returns:
            A reference to the overwritten z

        More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the
        addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten
        values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size.
        Useful for inference on extremely long sequences.

        It works as follows. We will make reference to variables used in the default forward implementation below.
        Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the
        "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask,
        and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for
        N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate
        tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the
        tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over
        pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains
        inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring
        total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks
        directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at
        the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column
        ahead of previously overwritten columns and can be recovered directly from z. After the first iteration,
        however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache,
        a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For
        0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith
        iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead.
        Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the
        z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache.
        After the final iteration, z has been completely overwritten and contains the triangular multiplicative update.
        If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case,
        peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small
        variables.
        """
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        def compute_projection_helper(pair, mask, a=True):
            # Gated projection (sigmoid(linear_g) * linear_p * mask), channels-first output.
            if a:
                linear_g = self.linear_a_g
                linear_p = self.linear_a_p
            else:
                linear_g = self.linear_b_g
                linear_p = self.linear_b_p

            pair = self.layer_norm_in(pair)
            p = linear_g(pair)
            p.sigmoid_()
            p *= linear_p(pair)
            p *= mask
            p = permute_final_dims(p, (2, 0, 1))
            return p

        def compute_projection(pair, mask, a=True, chunked=True):
            need_transpose = self._outgoing ^ a
            if not chunked:
                p = compute_projection_helper(pair, mask, a)
                if need_transpose:
                    p = p.transpose(-1, -2)
            else:
                # This computation is chunked so as not to exceed our 2.5x
                # budget with a large intermediate tensor
                linear_g = self.linear_a_g if a else self.linear_b_g
                c = linear_g.bias.shape[-1]
                out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1]
                p = pair.new_zeros(out_shape)
                for i in range(0, pair.shape[-3], inplace_chunk_size):
                    # NOTE(review): this first assignment is immediately overwritten
                    # below and has no effect.
                    pair_chunk = pair[..., i : i + inplace_chunk_size, :, :]
                    pair_chunk = compute_projection_helper(
                        pair[..., i : i + inplace_chunk_size, :, :],
                        mask[..., i : i + inplace_chunk_size, :, :],
                        a,
                    )
                    if need_transpose:
                        pair_chunk = pair_chunk.transpose(-1, -2)
                        p[..., i : i + inplace_chunk_size] = pair_chunk
                    else:
                        p[..., i : i + inplace_chunk_size, :] = pair_chunk

                    del pair_chunk

            return p

        # We start by fully manifesting a. In addition to the input, this
        # brings total memory consumption to 2x z (disregarding size of chunks)
        # [*, N, N, c]
        a = compute_projection(z, mask, True, chunked=True)

        if inplace_chunk_size is not None:
            n = a.shape[-1]
            half_n = n // 2 + n % 2
            row_dim = -3
            col_dim = -2
            b_chunk_dim = row_dim if self._outgoing else col_dim

            def empty_slicer(t):
                return [slice(None) for _ in t.shape]

            def slice_tensor(t, start, end, dim):
                # Slices start:end from the dim dimension of t
                s = empty_slicer(t)
                s[dim] = slice(start, end)
                return t[s]

            def flip_z_cache_(z_cache, z):
                # "Reorient" the z_cache (see below), filling it with quadrants
                # 3---recovered from the z_cache---and 4---recovered from z---
                # of the input tensor z.
                quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim)
                z_cache = z_cache.transpose(row_dim, col_dim)

                # If n is odd, we need to shrink the z_cache by one row
                z_cache = z_cache[..., : (n // 2), :, :]

                # Move the 3rd quadrant of z into the
                first_half_slicer = empty_slicer(z_cache)
                first_half_slicer[col_dim] = slice(0, half_n)
                z_cache[first_half_slicer] = quadrant_3

                # Get the fourth quadrant of z
                quadrant_4 = slice_tensor(z, half_n, None, row_dim)
                quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim)

                # Insert said quadrant into the rotated z-cache
                quadrant_3_slicer = empty_slicer(z_cache)
                quadrant_3_slicer[col_dim] = slice(half_n, None)

                z_cache[quadrant_3_slicer] = quadrant_4

                return z_cache

            # Initialize the z cache to the left half of z.
            z_cache_shape = list(z.shape)
            z_cache_shape[col_dim] = half_n
            z_cache = z.new_zeros(z_cache_shape)
            z_cache_slicer = empty_slicer(z_cache)
            z_cache_slicer[col_dim] = slice(0, half_n)
            z_cache.copy_(z[z_cache_slicer])
            z_cache_rotated = False

            # We need to reorient the z-cache at the halfway point, and we
            # don't want a single chunk to straddle that point. We contract one
            # of the chunks in the middle to address that problem.
            i_range = list(range(0, half_n, inplace_chunk_size))
            initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])]
            after_half = list(range(half_n, n, inplace_chunk_size))
            after_half_offsets = [inplace_chunk_size for _ in after_half]
            combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets)
            for i, offset in combined_range_with_offsets:
                if not z_cache_rotated and i >= half_n:
                    z_cache = flip_z_cache_(z_cache, z)
                    z_cache_rotated = True

                z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim)
                mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim)

                # NOTE(review): the clone below is discarded on both branches —
                # the col_dim branch re-slices z and the row_dim branch either
                # writes into the clone or replaces it from the z-cache.
                z_chunk_b = z_chunk_b.clone()
                if b_chunk_dim == col_dim:
                    z_chunk_b = slice_tensor(z, i, i + offset, col_dim)
                else:  # b_chunk_dim == row_dim
                    # In this case, the b-dimension (b_chunk_dim) is partially
                    # overwritten at the end of each iteration. We need to
                    # restore the missing component from the z-cache.
                    if not z_cache_rotated:
                        z_chunk_slicer = empty_slicer(z_chunk_b)
                        z_chunk_slicer[col_dim] = slice(0, half_n)
                        z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim)
                    else:
                        z_cache_offset = i - half_n
                        z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim)

                b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False)
                del z_chunk_b

                x_chunk = torch.matmul(a, b_chunk)
                x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
                x_chunk = self.layer_norm_out(x_chunk)
                x_chunk = self.linear_z(x_chunk)

                # The g dimension (col_dim) is parallel to and ahead of the
                # overwrites in z. We can extract the g chunk normally.
                z_chunk_g = slice_tensor(z, i, i + offset, col_dim)
                g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g))
                g_chunk.sigmoid_()
                del z_chunk_g

                x_chunk *= g_chunk

                # Write the columns into z in-place
                z_slicer = empty_slicer(z)
                z_slicer[col_dim] = slice(i, i + offset)
                if with_add:
                    z[z_slicer] += x_chunk
                else:
                    z[z_slicer] = x_chunk
        else:
            # Unchunked fallback: still in-place with respect to z.
            b = compute_projection(z, mask, False, False)
            x = torch.matmul(a, b)
            x = self.layer_norm_out(x)
            x = self.linear_z(x)
            g = self.linear_g(z)
            g.sigmoid_()
            x *= g
            if with_add:
                z += x
            else:
                z = x

        return z

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False,
        _inplace_chunk_size: Optional[int] = 256,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        if inplace_safe:
            # Memory-saving path: overwrites z in place (see _inference_forward).
            x = self._inference_forward(
                z,
                mask,
                inplace_chunk_size=_inplace_chunk_size,
                with_add=_add_with_inplace,
            )
            return x

        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        a = mask
        a = a * self.sigmoid(self.linear_a_g(z))
        a = a * self.linear_a_p(z)
        b = mask
        b = b * self.sigmoid(self.linear_b_g(z))
        b = b * self.linear_b_p(z)

        if is_fp16_enabled():
            # Do the large matmul in fp32 under autocast to avoid fp16 overflow.
            with torch.cuda.amp.autocast(enabled=False):
                x = self._combine_projections(a.float(), b.float())
        else:
            x = self._combine_projections(a, b)

        del a, b
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        g = self.sigmoid(self.linear_g(z))
        x = x * g

        return x
+
+
class EsmFoldPreTrainedModel(EsmPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Subclass `EsmPreTrainedModel` to deal with special init
    def _init_weights(self, module):
        """Initialize the weights according to each submodule's declared init scheme."""
        if isinstance(module, EsmFoldLinear):
            with torch.no_grad():
                if module.init_fn is not None:
                    module.init_fn(module.weight, module.bias)
                elif module.init == "default":
                    trunc_normal_init_(module.weight, scale=1.0)
                elif module.init == "relu":
                    trunc_normal_init_(module.weight, scale=2.0)
                elif module.init == "glorot":
                    nn.init.xavier_uniform_(module.weight, gain=1)
                elif module.init == "gating":
                    module.weight.fill_(0.0)
                    # Bug fix: `if module.bias:` truth-tested a Parameter; for any
                    # multi-element bias tensor that raises "Boolean value of Tensor
                    # with more than one element is ambiguous". Compare against None
                    # instead (nn.Linear sets bias to None when bias=False).
                    if module.bias is not None:
                        module.bias.fill_(1.0)
                elif module.init == "normal":
                    torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear")
                elif module.init == "final":
                    module.weight.fill_(0.0)
        elif isinstance(module, EsmFoldInvariantPointAttention):
            ipa_point_weights_init_(module.head_weights)
        elif isinstance(module, EsmFoldTriangularSelfAttentionBlock):
            # Zero out the "final" projections of every sub-block so each residual
            # branch starts as an identity mapping.
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight)
            torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight)
            torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias)

            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight)
            torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias)
            torch.nn.init.zeros_(module.pair_to_sequence.linear.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.weight)
            torch.nn.init.zeros_(module.seq_attention.o_proj.bias)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight)
            torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias)
        else:
            super()._init_weights(module)
+
+
class EsmFoldSelfAttention(nn.Module):
    """Multi-head self-attention with an optional sigmoid output gate and an
    optional externally supplied pairwise attention bias."""

    def __init__(self, embed_dim, num_heads, head_width, gated=False):
        super().__init__()
        assert embed_dim == num_heads * head_width

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_width = head_width

        # Fused query/key/value projection.
        self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.gated = gated
        if gated:
            # Zero weights + unit bias: the gate starts near-open (sigmoid(1)).
            self.g_proj = nn.Linear(embed_dim, embed_dim)
            torch.nn.init.zeros_(self.g_proj.weight)
            torch.nn.init.ones_(self.g_proj.bias)

        self.rescale_factor = self.head_width**-0.5

        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, x, mask=None, bias=None, indices=None):
        """Attend over `x` (B x L x embed_dim).

        mask: B x L_k booleans, 1 = valid position, 0 = padding.
        bias: B x L_q x L_k x num_heads scalar pairwise attention biases.

        Returns the output projection (B x L x embed_dim) and the attention
        weights.
        """
        qkv = self.proj(x).view(*x.shape[:2], self.num_heads, -1)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        q = self.rescale_factor * q
        attn = torch.einsum("...qc,...kc->...qk", q, k)

        # Fold in the externally computed pairwise bias, if any.
        if bias is not None:
            attn = attn + bias.permute(0, 3, 1, 2)

        # Padding tokens must never receive attention mass.
        if mask is not None:
            attn = attn.masked_fill(mask[:, None, None] == False, -np.inf)  # noqa: E712

        attn = nn.functional.softmax(attn, dim=-1)

        out = torch.einsum("...hqk,...hkc->...qhc", attn, v)
        out = out.reshape(*out.shape[:2], -1)

        if self.gated:
            out = self.g_proj(x).sigmoid() * out
        out = self.o_proj(out)

        return out, attn.permute(0, 3, 1, 2)
+
+
class EsmFoldDropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask along a particular dimension.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        super().__init__()

        self.r = r
        # Normalize to a list so forward() can always iterate.
        self.batch_dim = [batch_dim] if isinstance(batch_dim, int) else batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Draw the mask with size 1 along each shared dimension, then let
        # broadcasting replicate it across that axis.
        mask_shape = list(x.shape)
        if self.batch_dim is not None:
            for dim in self.batch_dim:
                mask_shape[dim] = 1
        return x * self.dropout(x.new_ones(mask_shape))
+
+
class EsmFoldSequenceToPair(nn.Module):
    """Builds pairwise features from sequence features via a broadcasted
    product and difference of projected per-position states."""

    def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
        super().__init__()

        self.layernorm = nn.LayerNorm(sequence_state_dim)
        self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
        self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)

        # Biases start at zero; weights keep the default init.
        torch.nn.init.zeros_(self.proj.bias)
        torch.nn.init.zeros_(self.o_proj.bias)

    def forward(self, sequence_state):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim

        Output:
            pairwise_state: B x L x L x pairwise_state_dim

        Intermediate state:
            B x L x L x 2*inner_dim
        """
        assert len(sequence_state.shape) == 3

        hidden = self.proj(self.layernorm(sequence_state))
        q, k = hidden.chunk(2, dim=-1)

        # Outer product/difference: entry (i, j) combines q at j with k at i.
        prod = q[:, None, :, :] * k[:, :, None, :]
        diff = q[:, None, :, :] - k[:, :, None, :]

        return self.o_proj(torch.cat([prod, diff], dim=-1))
+
+
class EsmFoldPairToSequence(nn.Module):
    """Projects normalized pairwise features to per-head attention biases."""

    def __init__(self, pairwise_state_dim, num_heads):
        super().__init__()

        self.layernorm = nn.LayerNorm(pairwise_state_dim)
        self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)

    def forward(self, pairwise_state):
        """
        Inputs:
            pairwise_state: B x L x L x pairwise_state_dim

        Output:
            pairwise_bias: B x L x L x num_heads
        """
        assert len(pairwise_state.shape) == 4
        return self.linear(self.layernorm(pairwise_state))
+
+
class EsmFoldResidueMLP(nn.Module):
    """Residual two-layer MLP: layernorm -> linear -> ReLU -> linear -> dropout,
    added back onto the input."""

    def __init__(self, embed_dim, inner_dim, dropout=0):
        super().__init__()

        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        residual = x
        return residual + self.mlp(x)
+
+
class EsmFoldTriangularSelfAttentionBlock(nn.Module):
    """One folding-trunk block.

    Updates the sequence state with pair-biased self-attention, then updates
    the pairwise state with an outer-product projection of the sequence state
    followed by triangular multiplicative updates and triangular (axial)
    attention, each applied as a residual with dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        sequence_state_dim = config.sequence_state_dim
        pairwise_state_dim = config.pairwise_state_dim
        sequence_num_heads = sequence_state_dim // config.sequence_head_width
        pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width

        self.layernorm_1 = nn.LayerNorm(sequence_state_dim)

        self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim)
        self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads)

        self.seq_attention = EsmFoldSelfAttention(
            sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True
        )
        self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True)
        self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False)

        self.tri_att_start = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True
        )
        self.tri_att_end = EsmFoldTriangleAttention(
            pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False
        )

        self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout)
        self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout)

        self.drop = nn.Dropout(config.dropout)
        # Row/column dropout share one dropout mask along a pairwise axis
        # (dim 2 for rows, dim 1 for columns) — see EsmFoldDropout.
        self.row_drop = EsmFoldDropout(config.dropout * 2, 2)
        self.col_drop = EsmFoldDropout(config.dropout * 2, 1)

    # `**__kwargs` absorbs extra keyword arguments (e.g. `residue_index`,
    # which the trunk passes to every block) that this block does not use.
    def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
        """
        Inputs:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
            mask: B x L boolean tensor of valid positions

        Output:
            sequence_state: B x L x sequence_state_dim
            pairwise_state: B x L x L x pairwise_state_dim
        """
        if len(sequence_state.shape) != 3:
            raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.")
        if len(pairwise_state.shape) != 4:
            raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.")
        if mask is not None and len(mask.shape) != 2:
            raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")

        batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
        pairwise_state_dim = pairwise_state.shape[3]

        if sequence_state_dim != self.config.sequence_state_dim:
            raise ValueError(
                "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got "
                f"{sequence_state_dim} != {self.config.sequence_state_dim}."
            )
        if pairwise_state_dim != self.config.pairwise_state_dim:
            raise ValueError(
                "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got "
                f"{pairwise_state_dim} != {self.config.pairwise_state_dim}."
            )
        if batch_dim != pairwise_state.shape[0]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != "
                f"{pairwise_state.shape[0]}."
            )
        if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]:
            raise ValueError(
                f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != "
                f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}."
            )

        # Update sequence state
        bias = self.pair_to_sequence(pairwise_state)

        # Self attention with bias + mlp.
        y = self.layernorm_1(sequence_state)
        y, _ = self.seq_attention(y, mask=mask, bias=bias)
        sequence_state = sequence_state + self.drop(y)
        sequence_state = self.mlp_seq(sequence_state)

        # Update pairwise state
        pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)

        # Axial attention with triangular bias.
        # tri_mask[b, i, j] is valid only when both i and j are valid positions.
        tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
        pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask))
        pairwise_state = pairwise_state + self.row_drop(
            self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )
        pairwise_state = pairwise_state + self.col_drop(
            self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
        )

        # MLP over pairs.
        pairwise_state = self.mlp_pair(pairwise_state)

        return sequence_state, pairwise_state
+
+
class EsmCategoricalMixture:
    """Categorical distribution over evenly spaced value bins on [start, end].

    `param` holds unnormalized logits whose last dimension indexes the bins;
    `v_bins` stores each bin's midpoint value.
    """

    def __init__(self, param, bins=50, start=0, end=1):
        # All tensors are of shape ..., bins.
        self.logits = param
        edges = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype)
        self.v_bins = (edges[:-1] + edges[1:]) / 2

    def log_prob(self, true):
        # Snap each continuous target (shape ...) to the index of the nearest
        # bin midpoint, then gather that bin's log-probability.
        midpoints = self.v_bins.reshape((1,) * true.ndim + (-1,))
        true_index = (true.unsqueeze(-1) - midpoints).abs().argmin(-1)
        log_probs = self.logits.log_softmax(-1)
        return torch.take_along_dim(log_probs, true_index.unsqueeze(-1), dim=-1).squeeze(-1)

    def mean(self):
        # Expected value: probability-weighted sum of the bin midpoints.
        probs = self.logits.softmax(-1)
        return (probs @ self.v_bins.unsqueeze(-1)).squeeze(-1)
+
+
def categorical_lddt(logits, bins=50):
    """Collapse per-atom pLDDT bin logits of shape (..., 37, bins) to their
    expected value under the binned categorical distribution."""
    mixture = EsmCategoricalMixture(logits, bins=bins)
    return mixture.mean()
+
+
def get_axial_mask(mask):
    """
    Helper to convert a B x L mask of valid positions to the axial mask used in
    row/column attentions.

    Input:
        mask: B x L tensor of booleans

    Output:
        (B*L) x L tensor of booleans — the per-position mask repeated once for
        every row, then flattened over the batch and row dimensions.
    """
    if mask is None:
        return None

    if len(mask.shape) != 2:
        raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.")

    batch, length = mask.shape
    expanded = mask.unsqueeze(1).expand(batch, length, length)
    return expanded.reshape(batch * length, length)
+
+
class EsmFoldRelativePosition(nn.Module):
    """Embeds clamped pairwise residue-index offsets into pairwise features."""

    def __init__(self, config):
        super().__init__()
        self.bins = config.position_bins

        # 2*bins + 1 possible clamped offsets, plus one extra slot: index 0 is
        # reserved for masked pairs.
        self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim)

    def forward(self, residue_index, mask=None):
        """
        Input:
            residue_index: B x L tensor of indices (dtype=torch.long)
            mask: B x L tensor of booleans

        Output:
            pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
        """
        if residue_index.dtype != torch.long:
            raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.")
        if mask is not None and residue_index.shape != mask.shape:
            raise ValueError(
                f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}."
            )

        # offsets[b, i, j] = residue_index[b, j] - residue_index[b, i],
        # clamped to [-bins, bins] and shifted so that 0 stays free for masks.
        offsets = residue_index[:, None, :] - residue_index[:, :, None]
        offsets = offsets.clamp(-self.bins, self.bins) + self.bins + 1

        if mask is not None:
            pair_mask = mask[:, None, :] * mask[:, :, None]
            offsets[pair_mask == False] = 0  # noqa: E712

        return self.embedding(offsets)
+
+
class EsmFoldAngleResnetBlock(nn.Module):
    """Pre-activation residual block used inside the angle resnet."""

    def __init__(self, config):
        super().__init__()

        # "relu"/"final" select the EsmFoldLinear initialization scheme
        # (see EsmFoldPreTrainedModel._init_weights).
        self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu")
        self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        residual = a

        out = self.linear_1(self.relu(a))
        out = self.linear_2(self.relu(out))

        return out + residual
+
+
class EsmFoldAngleResnet(nn.Module):
    """Predicts per-residue torsion angles as 2-vectors (Algorithm 20, lines
    11-14), returning both the raw and the L2-normalized form."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim)
        self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim)

        self.layers = nn.ModuleList(
            EsmFoldAngleResnetBlock(config) for _ in range(config.num_resnet_blocks)
        )

        self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2)

        self.relu = nn.ReLU()

    def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] single embedding
            s_initial:
                [*, C_hidden] single embedding as of the start of the StructureModule
        Returns:
            Tuple of the [*, no_angles, 2] unnormalized and normalized angles.
        """
        # The input ReLUs are absent from the supplement pseudocode but present
        # in the original source; kept for compatibility with the pretrained
        # weights.

        # [*, C_hidden]
        s_initial = self.linear_initial(self.relu(s_initial))
        s = self.linear_in(self.relu(s))
        s = s + s_initial

        for block in self.layers:
            s = block(s)

        s = self.relu(s)

        # [*, no_angles * 2] -> [*, no_angles, 2]
        s = self.linear_out(s)
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        # Normalize each 2-vector to the unit circle, guarding against a zero
        # norm with the configured epsilon.
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s**2, dim=-1, keepdim=True),
                min=self.config.epsilon,
            )
        )
        return unnormalized_s, s / norm_denom
+
+
class EsmFoldInvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22 (Invariant Point Attention): attention over the
    single representation combining scalar queries/keys, pair-bias terms, and
    3D "point" queries/keys expressed in each residue's rigid frame.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_dim
        c_z = config.pairwise_dim
        self.hidden_dim = config.ipa_dim
        self.num_heads = config.num_heads_ipa
        self.num_qk_points = config.num_qk_points
        self.num_v_points = config.num_v_points

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = config.ipa_dim * config.num_heads_ipa
        self.linear_q = EsmFoldLinear(c_s, hc)
        self.linear_kv = EsmFoldLinear(c_s, 2 * hc)

        hpq = config.num_heads_ipa * config.num_qk_points * 3
        self.linear_q_points = EsmFoldLinear(c_s, hpq)

        hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3
        self.linear_kv_points = EsmFoldLinear(c_s, hpkv)

        self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa)

        # Learned per-head weights for the point-distance attention term,
        # constrained positive via softplus in forward().
        self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa)))

        concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4)
        self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            _offload_inference:
                if True, temporarily moves the pair representation to CPU
                between its two uses to reduce peak device memory
            _z_reference_list:
                accepted for interface compatibility; not read by this
                implementation
        Returns:
            [*, N_res, C_s] single representation update
        """
        # Boxing z in a single-element list lets the offloading branch swap
        # the underlying tensor between devices in place.
        z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.hidden_dim, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        # Apply each residue's rigid transform to its query points
        # (local -> global frame, per Algorithm 22).
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3))

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if _offload_inference:
            # Guard against a second live reference that would defeat the
            # CPU offload by keeping the tensor alive on device.
            assert sys.getrefcount(z[0]) == 2
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if is_fp16_enabled():
            # Run the large matmul in fp32 under autocast-disabled scope.
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        a *= math.sqrt(1.0 / (3 * self.hidden_dim))
        a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        pt_att = pt_att**2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1)))
        head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
        pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # Additive mask: 0 for valid pairs, -inf (scaled by config.inf) for
        # pairs involving a padded position.
        square_mask = self.config.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        a = a + pt_att
        a = a + square_mask.unsqueeze(-3)
        a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        o_pt = torch.sum(
            (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
            dim=-2,
        )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        # Map the attended points back through the inverse rigid transforms.
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2)

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if _offload_inference:
            # Bring the pair representation back for its second use below.
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype)
        )

        return s
+
+
class EsmFoldBackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23: a single linear projection producing the
    per-residue rigid-body update vector.
    """

    def __init__(self, config):
        super().__init__()

        # init="final" selects zero initialization (see _init_weights), so the
        # first update is a no-op.
        self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final")

    # Fixed: the previous annotation claimed Tuple[Tensor, Tensor], but a
    # single tensor is returned.
    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # [*, 6]
        update = self.linear(s)

        return update
+
+
class EsmFoldStructureModuleTransitionLayer(nn.Module):
    """Residual three-linear-layer transition applied to the single
    representation inside the structure module."""

    def __init__(self, config):
        super().__init__()

        dim = config.sequence_dim
        self.linear_1 = EsmFoldLinear(dim, dim, init="relu")
        self.linear_2 = EsmFoldLinear(dim, dim, init="relu")
        self.linear_3 = EsmFoldLinear(dim, dim, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        residual = s

        s = self.relu(self.linear_1(s))
        s = self.relu(self.linear_2(s))
        s = self.linear_3(s)

        return s + residual
+
+
class EsmFoldStructureModuleTransition(nn.Module):
    """Stack of transition layers followed by dropout and layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.layers = nn.ModuleList(
            EsmFoldStructureModuleTransitionLayer(config) for _ in range(config.num_transition_layers)
        )

        self.dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm = LayerNorm(config.sequence_dim)

    def forward(self, s):
        for layer in self.layers:
            s = layer(s)

        return self.layer_norm(self.dropout(s))
+
+
class EsmFoldStructureModule(nn.Module):
    """AlphaFold-style structure module: iteratively refines per-residue rigid
    frames and predicts torsion angles from the trunk's single/pair
    representations, producing atom14 coordinates."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(config.sequence_dim)
        self.layer_norm_z = LayerNorm(config.pairwise_dim)

        self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim)

        self.ipa = EsmFoldInvariantPointAttention(config)

        self.ipa_dropout = nn.Dropout(config.dropout_rate)
        self.layer_norm_ipa = LayerNorm(config.sequence_dim)

        self.transition = EsmFoldStructureModuleTransition(config)
        self.bb_update = EsmFoldBackboneUpdate(config)
        self.angle_resnet = EsmFoldAngleResnet(config)

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
            _offload_inference:
                If True, keeps the pair representation on CPU while it is not
                in use, to reduce peak device memory at inference time.
        Returns:
            A dictionary of outputs, with per-iteration predictions stacked
            along a new leading dimension plus the final "single" state.
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if _offload_inference:
            # Ensure no other live reference would defeat the CPU offload.
            assert sys.getrefcount(evoformer_output_dict["pair"]) == 2
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        # All rigid frames start at the identity transform (quaternion format).
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.config.num_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            # Compose the predicted 6-dof update onto the current frames.
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype)

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype)

            scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            # NOTE(review): per the method name this blocks gradient flow
            # through the rotations between iterations, matching AlphaFold's
            # scheme — confirm against the Rigid implementation.
            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if _offload_inference:
            # Restore the (possibly offloaded) pair representation.
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device)

        # Stack each per-iteration prediction along a new leading dimension.
        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        # Lazily register residue-constant lookup tables as non-persistent
        # buffers on first use, with the caller's dtype/device.
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    residue_constants.restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    residue_constants.restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    residue_constants.restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    residue_constants.restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        """Thin wrapper over the module-level helper of the same name."""
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(self, r, f):  # [*, N, 8]  # [*, N]
        """Thin wrapper over the module-level helper of the same name."""
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
+
+
class EsmFoldingTrunk(nn.Module):
    """The ESMFold folding trunk: a stack of triangular self-attention blocks
    with recycling, feeding the structure module."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        c_s = config.sequence_state_dim
        c_z = config.pairwise_state_dim

        self.pairwise_positional_embedding = EsmFoldRelativePosition(config)

        self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)])

        self.recycle_bins = 15
        self.recycle_s_norm = nn.LayerNorm(c_s)
        self.recycle_z_norm = nn.LayerNorm(c_z)
        self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
        # Bin 0 is what the first pass sees (recycle_bins starts as zeros), so
        # its embedding row is zeroed to contribute nothing.
        self.recycle_disto.weight[0].detach().zero_()

        self.structure_module = EsmFoldStructureModule(config.structure_module)
        # Project trunk states down to the structure module's dimensions.
        self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim)
        self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim)

        self.chunk_size = config.chunk_size

    def set_chunk_size(self, chunk_size):
        # This parameter means the axial attention will be computed
        # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
        # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
        # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks.
        self.chunk_size = chunk_size

    def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles):
        """
        Inputs:
            seq_feats: B x L x C tensor of sequence features
            pair_feats: B x L x L x C tensor of pair features
            true_aa: B x L tensor of amino acid indices for the structure module
            residx: B x L long tensor giving the position in the sequence
            mask: B x L boolean tensor indicating valid residues
            no_recycles: optional number of recycling passes (None uses
                `config.max_recycles`)

        Output:
            the structure module's output dictionary, with the final trunk
            states added under "s_s" and "s_z"
        """

        device = seq_feats.device
        s_s_0 = seq_feats
        s_z_0 = pair_feats

        if no_recycles is None:
            no_recycles = self.config.max_recycles
        else:
            if no_recycles < 0:
                raise ValueError("Number of recycles must not be negative.")
            no_recycles += 1  # First 'recycle' is just the standard forward pass through the model.

        # One full trunk pass: add relative-position embeddings to the pair
        # state, then run every triangular self-attention block.
        def trunk_iter(s, z, residx, mask):
            z = z + self.pairwise_positional_embedding(residx, mask=mask)

            for block in self.blocks:
                s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
            return s, z

        s_s = s_s_0
        s_z = s_z_0
        recycle_s = torch.zeros_like(s_s)
        recycle_z = torch.zeros_like(s_z)
        recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)

        for recycle_idx in range(no_recycles):
            # Only the final recycling iteration runs with gradients enabled.
            with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]):
                # === Recycling ===
                recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device)
                recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device)
                recycle_z += self.recycle_disto(recycle_bins.detach()).to(device)

                s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)

                # === Structure module ===
                structure = self.structure_module(
                    {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
                    true_aa,
                    mask.float(),
                )

                recycle_s = s_s
                recycle_z = s_z
                # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
                recycle_bins = EsmFoldingTrunk.distogram(
                    structure["positions"][-1][:, :, :3],
                    3.375,
                    21.375,
                    self.recycle_bins,
                )

        structure["s_s"] = s_s
        structure["s_z"] = s_z

        return structure

    @staticmethod
    def distogram(coords, min_bin, max_bin, num_bins):
        """Bin pairwise CB-CB distances into `num_bins` distance bins."""
        # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
        boundaries = torch.linspace(
            min_bin,
            max_bin,
            num_bins - 1,
            device=coords.device,
        )
        # Square the boundaries so squared distances can be compared directly,
        # avoiding a sqrt.
        boundaries = boundaries**2
        N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
        # Infer CB coordinates.
        b = CA - N
        c = C - CA
        a = b.cross(c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
        bins = torch.sum(dists > boundaries, dim=-1)  # [..., L, L]
        return bins
+
+
+# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare
+# the outputs for downstream use.
+
+
@add_start_docstrings(
    """
    ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed
    by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to
    the rest of the model combined! It outputs a dictionary containing predicted structural information about the input
    protein(s).
    """,
    ESM_START_DOCSTRING,
)
class EsmForProteinFolding(EsmPreTrainedModel):
    _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # Bin count shared by the distogram head and the pTM head.
        self.distogram_bins = 64

        self.esm = EsmModel(config, add_pooling_layer=False)

        # The ESM-2 language-model stem is frozen; only the folding trunk and heads train.
        self.esm.requires_grad_(False)
        if self.config.esmfold_config.fp16_esm:
            self.esm.half()

        self.esm_feats = self.config.hidden_size
        self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads
        self.esm_layers = self.config.num_hidden_layers
        # Lookup table from AF2 residue indices (shifted by 1; 0 is padding) to ESM vocab ids.
        self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list))
        # Learned softmax weights that combine the stem's per-layer hidden states.
        self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1))

        trunk_config = self.config.esmfold_config.trunk
        c_s = trunk_config.sequence_state_dim
        c_z = trunk_config.pairwise_state_dim
        self.esm_s_mlp = nn.Sequential(
            LayerNorm(self.esm_feats),
            nn.Linear(self.esm_feats, c_s),
            nn.ReLU(),
            nn.Linear(c_s, c_s),
        )

        # 0 is padding, N is unknown residues, N + 1 is mask.
        self.n_tokens_embed = residue_constants.restype_num + 3
        self.pad_idx = 0
        self.unk_idx = self.n_tokens_embed - 2
        self.mask_idx = self.n_tokens_embed - 1
        # Fix: these lookups previously used empty strings (`vocab_list.index("")`), which
        # raises ValueError (or finds the wrong entry) at construction time. Use the actual
        # ESM special-token strings.
        self.esm_dict_cls_idx = self.config.vocab_list.index("<cls>")
        self.esm_dict_mask_idx = self.config.vocab_list.index("<mask>")
        self.esm_dict_eos_idx = self.config.vocab_list.index("<eos>")
        self.esm_dict_padding_idx = self.config.vocab_list.index("<pad>")
        if self.config.esmfold_config.embed_aa:
            self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)

        self.trunk = EsmFoldingTrunk(trunk_config)

        self.distogram_head = nn.Linear(c_z, self.distogram_bins)
        self.ptm_head = nn.Linear(c_z, self.distogram_bins)
        self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
        self.lddt_bins = 50
        structure_module_config = trunk_config.structure_module
        # Predicts, per residue, 37 atom slots x lddt_bins confidence bins.
        self.lddt_head = nn.Sequential(
            nn.LayerNorm(structure_module_config.sequence_dim),
            nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim),
            nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins),
        )

    @staticmethod
    def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor:
        """Build the AF2-residue-index -> ESM-vocab-index lookup tensor."""
        # Remember that t is shifted from residue_constants by 1 (0 is padding).
        # Fix: the padding entry previously used `vocab_list.index("")`; it must be the pad token.
        esm_reorder = [vocab_list.index("<pad>")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x]
        return torch.tensor(esm_reorder)

    @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        masking_pattern: Optional[torch.Tensor] = None,
        num_recycles: Optional[int] = None,
    ) -> EsmForProteinFoldingOutput:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, EsmForProteinFolding

        >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
        >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide
        >>> outputs = model(**inputs)
        >>> folded_positions = outputs.positions
        ```

        """
        cfg = self.config.esmfold_config

        aa = input_ids  # B x L
        B = aa.shape[0]
        L = aa.shape[1]
        device = input_ids.device
        if attention_mask is None:
            attention_mask = torch.ones_like(aa, device=device)
        if position_ids is None:
            position_ids = torch.arange(L, device=device).expand_as(input_ids)

        # === ESM ===
        esmaa = self.af2_idx_to_esm_idx(aa, attention_mask)

        if masking_pattern is not None:
            masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern)
        else:
            masked_aa = aa
            mlm_targets = None

        # We get sequence and pair representations from whatever version of ESM /
        # configuration we are using. The sequence representation esm_s is always
        # present. The pair embedding esm_z may be present depending on the
        # configuration of the model. If esm_z is not used by the model then it
        # is returned as None here.
        esm_s = self.compute_language_model_representations(esmaa)

        # Convert esm_s and esm_z, if present, to the precision used by the trunk and
        # the structure module. These tensors may be a lower precision if, for example,
        # we're running the language model in fp16 precision.
        esm_s = esm_s.to(self.esm_s_combine.dtype)

        if cfg.esm_ablate_sequence:
            esm_s = esm_s * 0

        esm_s = esm_s.detach()

        # === preprocessing ===
        # Collapse the per-layer hidden states into a single representation via learned
        # softmax weights over layers.
        esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
        s_s_0 = self.esm_s_mlp(esm_s)

        s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim)

        if self.config.esmfold_config.embed_aa:
            s_s_0 += self.embedding(masked_aa)

        structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles)
        # Documenting what we expect:
        structure = {
            k: v
            for k, v in structure.items()
            if k
            in [
                "s_z",
                "s_s",
                "frames",
                "sidechain_frames",
                "unnormalized_angles",
                "angles",
                "positions",
                "states",
            ]
        }

        # Add BERT mask for the loss to use, if available.
        # Fix: `if mlm_targets:` evaluated tensor truthiness, which raises a RuntimeError
        # for any multi-element tensor; compare against None instead.
        if mlm_targets is not None:
            structure["mlm_targets"] = mlm_targets

        disto_logits = self.distogram_head(structure["s_z"])
        # Distances are symmetric, so symmetrize the logits over the two residue axes.
        disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
        structure["distogram_logits"] = disto_logits

        lm_logits = self.lm_head(structure["s_s"])
        structure["lm_logits"] = lm_logits

        structure["aatype"] = aa
        make_atom14_masks(structure)
        # Of course, this doesn't respect the true mask because it doesn't know about it...
        # We're not going to properly mask change of index tensors:
        # "residx_atom14_to_atom37",
        # "residx_atom37_to_atom14",
        for k in [
            "atom14_atom_exists",
            "atom37_atom_exists",
        ]:
            structure[k] *= attention_mask.unsqueeze(-1)
        structure["residue_index"] = position_ids

        lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins)
        structure["lddt_head"] = lddt_head
        plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
        structure["plddt"] = plddt

        ptm_logits = self.ptm_head(structure["s_z"])
        structure["ptm_logits"] = ptm_logits
        structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins)
        structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins))

        return EsmForProteinFoldingOutput(**structure)

    def af2_idx_to_esm_idx(self, aa, mask):
        """Map AF2 residue indices to ESM vocabulary indices, sending padded positions to pad."""
        # avoid indexing on different devices
        if self.af2_to_esm.device != aa.device:
            self.af2_to_esm = self.af2_to_esm.to(aa.device)
        aa = (aa + 1).masked_fill(mask != 1, 0)
        return self.af2_to_esm[aa]

    def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor:
        """Run the frozen ESM-2 stem and return stacked per-layer hidden states of shape (B, L, nLayers + 1, C)."""
        device = next(self.parameters()).device
        B, L = esmaa.shape  # B = batch size, L = sequence length.

        if self.config.esmfold_config.bypass_lm:
            # Fix: `Tensor.size` is a method (`size(0)`, not `size[0]`), and `torch.zeros`
            # does not accept a -1 dimension. The placeholder must match the stacked
            # hidden-state shape (B, L, num_layers + 1, esm_feats) produced below.
            esm_s = torch.zeros(B, L, self.esm_s_combine.size(0), self.esm_feats, device=device)
            return esm_s

        bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx
        bos = esmaa.new_full((B, 1), bosi)
        eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx)
        esmaa = torch.cat([bos, esmaa, eos], dim=1)
        # Use the first padding index as eos during inference.
        esmaa[range(B), (esmaa != 1).sum(1)] = eosi

        # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map)
        # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models,
        # esm_z is always None
        esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"]
        esm_s = torch.stack(esm_hidden_states, dim=2)

        # Strip the BOS/EOS columns that were added above.
        esm_s = esm_s[:, 1:-1]  # B, L, nLayers, C

        return esm_s

    def bert_mask(self, aa, esmaa, mask, pattern):
        """Apply a BERT-style masking pattern; returns (masked_aa, masked_esmaa, mlm_targets)."""
        new_aa = aa.clone()
        target = aa.clone()
        new_esmaa = esmaa.clone()
        new_aa[pattern == 1] = self.mask_idx
        # Unmasked positions contribute no MLM target (0).
        target[pattern != 1] = 0
        new_esmaa[pattern == 1] = self.esm_dict_mask_idx
        return new_aa, new_esmaa, target

    @torch.no_grad()
    def infer(
        self,
        seqs: Union[str, List[str]],
        position_ids=None,
    ):
        """Run the model directly on one or more raw amino-acid sequence strings."""
        if isinstance(seqs, str):
            lst = [seqs]
        else:
            lst = seqs
        # Returns the raw outputs of the model given an input sequence.
        device = next(self.parameters()).device
        aatype = collate_dense_tensors(
            [
                torch.from_numpy(
                    residue_constants.sequence_to_onehot(
                        sequence=seq,
                        mapping=residue_constants.restype_order_with_x,
                        map_unknown_to_x=True,
                    )
                )
                .to(device)
                .argmax(dim=1)
                for seq in lst
            ]
        )  # B=1 x L
        mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst])
        position_ids = (
            torch.arange(aatype.shape[1], device=device).expand(len(lst), -1)
            if position_ids is None
            else position_ids.to(device)
        )
        if position_ids.ndim == 1:
            position_ids = position_ids.unsqueeze(0)
        return self.forward(
            aatype,
            mask,
            position_ids=position_ids,
        )

    @staticmethod
    def output_to_pdb(output: Dict) -> List[str]:
        """Returns the pdb (file) string from the model given the model output."""
        output = {k: v.to("cpu").numpy() for k, v in output.items()}
        pdbs = []
        final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
        final_atom_mask = output["atom37_atom_exists"]
        for i in range(output["aatype"].shape[0]):
            aa = output["aatype"][i]
            pred_pos = final_atom_positions[i]
            mask = final_atom_mask[i]
            # PDB residue numbering is 1-based.
            resid = output["residue_index"][i] + 1
            pred = OFProtein(
                aatype=aa,
                atom_positions=pred_pos,
                atom_mask=mask,
                residue_index=resid,
                b_factors=output["plddt"][i],
            )
            pdbs.append(to_pdb(pred))
        return pdbs

    def infer_pdb(self, seqs, *args, **kwargs) -> str:
        """Returns the pdb (file) string from the model given an input sequence."""
        assert isinstance(seqs, str)
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)[0]

    def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]:
        """Returns the pdb (file) strings from the model given a list of input sequences."""
        output = self.infer(seqs, *args, **kwargs)
        return self.output_to_pdb(output)
+
+
# Public API of this module.
__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"]
diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py b/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..71698486dab0adfece1feb14ce91b4500e24382c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py
@@ -0,0 +1,1575 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""TF 2.0 ESM model."""
+
+from __future__ import annotations
+
+import os
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from ...modeling_tf_outputs import (
+ TFBaseModelOutputWithPastAndCrossAttentions,
+ TFBaseModelOutputWithPoolingAndCrossAttentions,
+ TFMaskedLMOutput,
+ TFSequenceClassifierOutput,
+ TFTokenClassifierOutput,
+)
+from ...modeling_tf_utils import (
+ TFMaskedLanguageModelingLoss,
+ TFModelInputType,
+ TFPreTrainedModel,
+ TFSequenceClassificationLoss,
+ TFTokenClassificationLoss,
+ get_initializer,
+ keras,
+ shape_list,
+ unpack_inputs,
+)
+from ...tf_utils import check_embeddings_within_bounds, stable_softmax
+from ...utils import logging
+from .configuration_esm import EsmConfig
+
+
logger = logging.get_logger(__name__)

# Default checkpoint/config names referenced by the auto-generated docstrings below.
_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D"
_CONFIG_FOR_DOC = "EsmConfig"
+
+
def rotate_half(x):
    """Split the last dimension of `x` into two halves and rotate: (a, b) -> (-b, a)."""
    first_half, second_half = tf.split(x, num_or_size_splits=2, axis=-1)
    rotated = tf.concat((-second_half, first_half), axis=-1)
    return rotated
+
+
def apply_rotary_pos_emb(x, cos, sin):
    """Apply rotary position embeddings to `x` using precomputed cos/sin tables."""
    # The tables may be longer than this tensor's sequence length; trim to match.
    seq_len = tf.shape(x)[-2]
    cos_trimmed = cos[:, :, :seq_len, :]
    sin_trimmed = sin[:, :, :seq_len, :]
    return x * cos_trimmed + rotate_half(x) * sin_trimmed
+
+
def symmetrize(x):
    """Make `x` symmetric in its final two dimensions (used for contact prediction)."""
    # matrix_transpose swaps only the last two dimensions.
    transposed = tf.linalg.matrix_transpose(x)
    return x + transposed
+
+
def average_product_correct(x):
    """Apply average product correction (APC), used for contact prediction."""
    row_totals = tf.reduce_sum(x, -1, keepdims=True)
    col_totals = tf.reduce_sum(x, -2, keepdims=True)
    grand_total = tf.reduce_sum(x, (-1, -2), keepdims=True)

    # Expected value under independence of rows and columns.
    expected = row_totals * col_totals
    expected = expected / grand_total
    return x - expected
+
+
class TFRotaryEmbedding(keras.layers.Layer):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    """

    def __init__(self, dim: int, name=None):
        super().__init__(name=name)
        # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation
        # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at
        # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the
        # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that
        # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our
        # models give different outputs from the original.
        self.dim = dim

    def build(self, input_shape):
        """Create the non-trainable `inv_freq` weight and fill it with the standard RoPE inverse frequencies."""
        super().build(input_shape)
        self.inv_freq = self.add_weight(
            "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False
        )
        # 1 / 10000^(2i/dim) for i in [0, dim/2): one frequency per rotated coordinate pair.
        self.inv_freq.assign(
            1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim))
        )

    def _compute_cos_sin(self, x, seq_dimension=2):
        # Build cos/sin tables for all positions up to x's sequence length, shaped
        # [1, 1, seq_len, dim] so they broadcast over batch and heads.
        seq_len = tf.shape(x)[seq_dimension]

        t = tf.range(seq_len, dtype=self.inv_freq.dtype)
        freqs = tf.einsum("i, j -> ij", t, self.inv_freq)  # Outer multiplication
        emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :]

        return tf.cos(emb), tf.sin(emb)

    def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Return `(q, k)` with rotary position embeddings applied to both."""
        cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb(q, cos_emb, sin_emb),
            apply_rotary_pos_emb(k, cos_emb, sin_emb),
        )
+
+
class TFEsmContactPredictionHead(keras.layers.Layer):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        bias=True,
        eos_idx: int = 2,
        name=None,
    ):
        super().__init__(name=name)
        self.eos_idx = eos_idx
        self.in_features = in_features
        # Single-output logistic regression over the stacked attention-head features.
        self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression")

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "regression", None) is not None:
            with tf.name_scope(self.regression.name):
                self.regression.build((None, self.in_features))

    def call(self, tokens, attentions):
        """Predict per-residue-pair contact probabilities from stacked attention maps."""
        # remove eos token attentions
        eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype)
        # Pairwise mask: a (i, j) pair survives only if neither i nor j is the EOS token.
        eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2)
        attentions = attentions * eos_mask[:, None, None, :, :]
        attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = shape_list(attentions)
        # Fold (layers, heads) into a single channel dimension.
        attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen))

        # features: batch x channels x tokens x tokens (symmetric)
        attentions = average_product_correct(symmetrize(attentions))
        # Move channels last so the Dense layer regresses over them per position pair.
        attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
        return tf.squeeze(self.regression(attentions), 3)
+
+
class TFEsmEmbeddings(keras.layers.Layer):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.word_embeddings = keras.layers.Embedding(
            config.vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="word_embeddings",
        )
        self.position_embeddings = keras.layers.Embedding(
            config.max_position_embeddings,
            config.hidden_size,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="position_embeddings",
        )

        if config.emb_layer_norm_before:
            self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        else:
            self.layer_norm = None
        # Matt: I think this line was copied incorrectly from BERT, disabling for now
        # self.dropout = Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.position_ids = tf.range(config.max_position_embeddings)[None, :]

        self.padding_idx = config.pad_token_id
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id
        self.config = config

    def call(
        self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        """Embed token ids (or raw embeddings), applying ESM's token-dropout rescaling and position embeddings."""
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        if self.token_dropout:
            embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings)
            mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs
            src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32)
            masked_tokens = input_ids == self.mask_token_id
            mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths
            embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Zero out embeddings at padded positions.
            embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype)
        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: tf.Tensor

        Returns: tf.Tensor
        """
        input_shape = shape_list(inputs_embeds)[:-1]
        sequence_length = input_shape[1]

        # Positions start just after padding_idx, mirroring create_position_ids_from_input_ids.
        position_ids = tf.range(
            start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64
        )
        return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "word_embeddings", None) is not None:
            with tf.name_scope(self.word_embeddings.name):
                self.word_embeddings.build(None)
        if getattr(self, "position_embeddings", None) is not None:
            with tf.name_scope(self.position_embeddings.name):
                self.position_embeddings.build(None)
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])
+
+
class TFEsmSelfAttention(keras.layers.Layer):
    """Multi-head self/cross attention with optional rotary or relative-key position embeddings and a KV cache."""

    def __init__(self, config, position_embedding_type=None, name=None):
        super().__init__(name=name)
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )

        self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        self.rotary_embeddings = None
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = keras.layers.Embedding(
                2 * config.max_position_embeddings - 1,
                self.attention_head_size,
                embeddings_initializer=get_initializer(config.initializer_range),
            )
        elif self.position_embedding_type == "rotary":
            self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings")

        self.is_decoder = config.is_decoder
        self.config = config

    def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
        """Reshape (B, L, all_head_size) -> (B, num_heads, L, head_size)."""
        new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
        x = tf.reshape(x, new_x_shape)
        return tf.transpose(x, perm=(0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor | None = None,
        head_mask: tf.Tensor | None = None,
        encoder_hidden_states: tf.Tensor | None = None,
        encoder_attention_mask: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        output_attentions: Optional[bool] = False,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """Compute attention; returns (context, [attn_probs,] [past_key_value]) depending on flags."""
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Decoder self-attention: append the new K/V to the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # ESM code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = shape_list(hidden_states)[1]
            position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1)
            position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0)
            distance = position_ids_l - position_ids_r
            # Shift distances into [0, 2 * max_position_embeddings - 2] for the embedding lookup.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = tf.cast(positional_embedding, query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in EsmModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = attention_probs @ value_layer

        # (B, num_heads, L, head_size) -> (B, L, all_head_size)
        context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3))
        new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size]
        context_layer = tf.reshape(context_layer, new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
        if getattr(self, "rotary_embeddings", None) is not None:
            with tf.name_scope(self.rotary_embeddings.name):
                self.rotary_embeddings.build(None)
+
+
class TFEsmSelfOutput(keras.layers.Layer):
    """Projects self-attention output back to the hidden size, applies dropout, and adds the residual."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states, input_tensor, training=False):
        projected = self.dense(hidden_states)
        regularized = self.dropout(projected, training=training)
        # Residual connection back to the attention block's input.
        return regularized + input_tensor

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.config.hidden_size])
+
+
class TFEsmAttention(keras.layers.Layer):
    """Pre-LayerNorm attention block: LayerNorm -> self-attention -> output projection with residual."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.self = TFEsmSelfAttention(config, name="self")
        self.output_layer = TFEsmSelfOutput(config, name="output")
        self.pruned_heads = set()
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def prune_heads(self, heads):
        # Head pruning is not supported for this model.
        raise NotImplementedError

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        training=False,
    ):
        normed_states = self.LayerNorm(hidden_states)
        attn_results = self.self(
            normed_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
            training,
        )
        # The residual uses the *un-normalized* input (pre-LN architecture).
        projected = self.output_layer(attn_results[0], hidden_states)
        # Propagate attention probabilities / cache entries if the sub-layer returned them.
        return (projected,) + attn_results[1:]

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        for attr_name in ("self", "output_layer"):
            sub_layer = getattr(self, attr_name, None)
            if sub_layer is not None:
                with tf.name_scope(sub_layer.name):
                    sub_layer.build(None)
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
class TFEsmIntermediate(keras.layers.Layer):
    """Feed-forward expansion: dense projection to the intermediate size followed by GELU."""

    def __init__(self, config: EsmConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        expanded = self.dense(inputs=hidden_states)
        return tf.nn.gelu(expanded)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is None:
            return
        with tf.name_scope(dense.name):
            dense.build([None, None, self.config.hidden_size])
+
+
class TFEsmOutput(keras.layers.Layer):
    """Feed-forward contraction: dense projection back to the hidden size, dropout, and residual add."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states, input_tensor, training=False):
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted, training=training)
        # Residual connection back to the FFN block's input.
        return contracted + input_tensor

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        dense = getattr(self, "dense", None)
        if dense is None:
            return
        # This layer consumes intermediate-size activations, not hidden-size ones.
        with tf.name_scope(dense.name):
            dense.build([None, None, self.config.intermediate_size])
+
+
class TFEsmLayer(keras.layers.Layer):
    """One ESM transformer block: self-attention, optional cross-attention (decoder only),
    and a pre-LayerNorm feed-forward sub-layer, each with a residual connection."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Sequence axis; kept for parity with the PyTorch feed-forward chunking helper.
        self.seq_len_dim = 1
        self.attention = TFEsmAttention(config, name="attention")
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            # NOTE(review): unlike the other sub-layers this one is created without an
            # explicit `name=`, so its weights live under an auto-generated scope name —
            # confirm this matches the expected checkpoint layout before changing it.
            self.crossattention = TFEsmAttention(config)
        self.intermediate = TFEsmIntermediate(config, name="intermediate")
        self.output_layer = TFEsmOutput(config, name="output")
        self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.config = config

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        training=False,
    ):
        """Run one transformer block.

        Returns a tuple `(layer_output, *attention_weights)`, with the present
        key/value cache appended as the last element when `is_decoder` is True.
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
            training=training,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
                training=training,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        # Pre-LN feed-forward: normalize the attention output, expand, project back,
        # and add the residual to the (un-normalized) attention output.
        layernorm_output = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(hidden_states=layernorm_output)
        layer_output = self.output_layer(
            hidden_states=intermediate_output, input_tensor=attention_output, training=training
        )
        outputs = (layer_output,) + outputs  # add attentions if we output them

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def build(self, input_shape=None):
        # NOTE(review): `crossattention` (when present) is not built here — presumably it
        # is built lazily on first call; confirm weight creation order is acceptable.
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "output_layer", None) is not None:
            with tf.name_scope(self.output_layer.name):
                self.output_layer.build(None)
        if getattr(self, "LayerNorm", None) is not None:
            with tf.name_scope(self.LayerNorm.name):
                self.LayerNorm.build([None, None, self.config.hidden_size])
+
+
class TFEsmEncoder(keras.layers.Layer):
    """Stack of `TFEsmLayer` blocks followed by a final LayerNorm
    (`emb_layer_norm_after`) over the last hidden states."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.config = config
        self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
        self.emb_layer_norm_after = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="emb_layer_norm_after"
        )

    def call(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        training=False,
    ):
        """Run all layers, optionally collecting per-layer hidden states, self/cross
        attentions, and (when `use_cache`) the decoder key/value cache.

        Returns a `TFBaseModelOutputWithPastAndCrossAttentions` when `return_dict`
        is True, otherwise a tuple of the non-None collected values.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # Record the input to this layer (i.e. the previous layer's output).
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                training,
            )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer appends its present key/value cache as the last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        # NOTE(review): truthiness check on a Keras layer — always True once constructed;
        # presumably a guard carried over from the PyTorch port.
        if self.emb_layer_norm_after:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return TFBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "emb_layer_norm_after", None) is not None:
            with tf.name_scope(self.emb_layer_norm_after.name):
                self.emb_layer_norm_after.build([None, None, self.config.hidden_size])
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)
+
+
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm
class TFEsmPooler(keras.layers.Layer):
    """Pools the sequence by passing the first token's hidden state through a
    tanh-activated dense layer (BERT-style pooler)."""

    def __init__(self, config: EsmConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
+
+
class TFEsmPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class associated with every ESM checkpoint.
    config_class = EsmConfig
    # Attribute name under which derived models store the base `TFEsmMainLayer`.
    base_model_prefix = "esm"
+
+
# Shared class-level documentation, injected into model classes via `@add_start_docstrings`.
ESM_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a
    regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior.

    Parameters:
        config ([`EsmConfig`]): Model configuration class with all the parameters of the
            model. Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Forward-pass argument documentation; the `{0}` placeholder is filled with the input
# shape string (e.g. "batch_size, sequence_length") by
# `@add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format(...))`.
ESM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`tf.Tensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
+
+
@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class TFEsmMainLayer(keras.layers.Layer):
    """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    # `position_ids` is derived, not stored, so it may be legitimately absent from checkpoints.
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config, add_pooling_layer=True, name=None, **kwargs):
        super().__init__(name=name, **kwargs)

        self.config = config
        self.is_decoder = config.is_decoder

        self.embeddings = TFEsmEmbeddings(config, name="embeddings")
        self.encoder = TFEsmEncoder(config, name="encoder")
        self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None

        # Predicts residue-residue contacts from the stacked attention maps
        # (one input feature per layer/head pair).
        self.contact_head = TFEsmContactPredictionHead(
            in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head"
        )

    def build(self, input_shape=None):
        """Build every sub-module under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)
        if getattr(self, "contact_head", None) is not None:
            with tf.name_scope(self.contact_head.name):
                self.contact_head.build(None)

    def get_input_embeddings(self):
        """Return the token (word) embedding layer."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: tf.Variable):
        """Replace the token embedding weights and update the tracked vocab size."""
        self.embeddings.word_embeddings.weight = value
        self.embeddings.vocab_size = shape_list(value)[0]

    def _prune_heads(self, heads_to_prune):
        # Head pruning is not supported for the TF ESM implementation.
        raise NotImplementedError

    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        """Embed the inputs, build the (possibly causal) extended attention mask,
        run the encoder stack, and optionally pool the first-token representation.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given.
        """
        # Caching is only meaningful for decoder configurations.
        if not self.config.is_decoder:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_key_values_length = 0
            past_key_values = [None] * len(self.encoder.layer)
        else:
            # Cached key shape ends in (..., past_seq_len, head_dim).
            past_key_values_length = shape_list(past_key_values[0][0])[-2]

        if attention_mask is None:
            attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            training=training,
        )

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        attention_mask_shape = shape_list(attention_mask)

        mask_seq_length = seq_length + past_key_values_length
        # Copied from `modeling_tf_t5.py`
        # Provided a padding mask of dimensions [batch_size, mask_seq_length]
        # - if the model is a decoder, apply a causal mask in addition to the padding mask
        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
        if self.is_decoder:
            seq_ids = tf.range(mask_seq_length)
            causal_mask = tf.less_equal(
                tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)),
                seq_ids[None, :, None],
            )
            causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype)
            extended_attention_mask = causal_mask * attention_mask[:, None, :]
            attention_mask_shape = shape_list(extended_attention_mask)
            extended_attention_mask = tf.reshape(
                extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2])
            )
            if past_key_values[0] is not None:
                # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length]
                extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
        else:
            extended_attention_mask = tf.reshape(
                attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1])
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
        one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
        ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
        extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

        # Copied from `modeling_tf_t5.py` with -1e9 -> -10000
        if self.is_decoder and encoder_attention_mask is not None:
            # If a 2D ou 3D attention mask is provided for the cross-attention
            # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
            # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]

            # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
            # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
            # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask,
            # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2)))

            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (
                sequence_output,
                pooled_output,
            ) + encoder_outputs[1:]

        return TFBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Run a forward pass collecting all attention maps, zero out attention at
        padding positions, and feed the stacked maps to the contact head."""
        attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions
        attns = tf.stack(attns, axis=1)  # Matches the original model layout
        # In the original model, attentions for padding tokens are completely zeroed out.
        # This makes no difference most of the time because the other tokens won't attend to them,
        # but it does for the contact prediction task, which takes attentions as input,
        # so we have to mimic that here.
        attention_mask = tf.cast(attention_mask, attns.dtype)
        attns *= attention_mask[:, None, None, None]
        attns *= attention_mask[:, None, None, :, None]
        return self.contact_head(tokens, attns)
+
+
@add_start_docstrings(
    "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.",
    ESM_START_DOCSTRING,
)
class TFEsmModel(TFEsmPreTrainedModel):
    """Thin `TFPreTrainedModel` wrapper that delegates all computation to `TFEsmMainLayer`."""

    def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        # The main layer holds embeddings, encoder, pooler and contact head.
        self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
        r"""
        encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
        """
        # Pure delegation: every argument is forwarded unchanged to the main layer.
        return self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate contact prediction to the main layer."""
        return self.esm.predict_contacts(tokens, attention_mask)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        main_layer = getattr(self, "esm", None)
        if main_layer is not None:
            with tf.name_scope(main_layer.name):
                main_layer.build(None)
+
+
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss):
    """ESM encoder (no pooler) with a masked-language-modeling head on top."""

    # `position_ids` is derived at runtime; the pooler is intentionally absent here.
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        self.lm_head = TFEsmLMHead(config, name="lm_head")
        if config.tie_word_embeddings:
            # Ensure word embeddings are built so that we actually have something to tie.
            # BUGFIX: the scope path must use "/" explicitly — os.path.join would emit
            # backslashes on Windows, which are not legal TF name-scope separators.
            with tf.name_scope("/".join([self._name_scope(), "esm", "embeddings", "word_embeddings"])):
                self.esm.embeddings.word_embeddings.build((None, None))
            # Tie the LM head's projection to the input embedding matrix.
            self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0]

    def get_output_embeddings(self):
        """Return the output projection (the tied embedding weight or the decoder Dense)."""
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection used by the LM head."""
        self.lm_head.decoder = new_embeddings

    def get_lm_head(self):
        """Return the masked-LM head layer."""
        return self.lm_head

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
        mask="<mask>",
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
        encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        kwargs (`Dict[str, any]`, *optional*, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores)

        if not return_dict:
            # Tuple layout mirrors TFMaskedLMOutput, with the loss prepended when present.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return TFMaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def predict_contacts(self, tokens, attention_mask):
        """Delegate contact prediction to the main layer."""
        return self.esm.predict_contacts(tokens, attention_mask)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        if getattr(self, "lm_head", None) is not None:
            with tf.name_scope(self.lm_head.name):
                self.lm_head.build(None)
+
+
class TFEsmLMHead(keras.layers.Layer):
    """ESM Head for masked language modeling."""

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.dense = keras.layers.Dense(
            config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        if config.tie_word_embeddings:
            # When tied, `decoder` is later assigned the input-embedding weight matrix
            # (a variable), not a Dense layer — see the owning model's __init__.
            self.decoder = None
        else:
            self.decoder = keras.layers.Dense(
                config.vocab_size,
                kernel_initializer=get_initializer(config.initializer_range),
                name="decoder",
                use_bias=False,
            )
        self.config = config

    def build(self, input_shape=None):
        # Separate bias to match the PT model and allow weight cross-loading to work
        # Put it in the build so it gets the right name when adding it as a weight
        if self.built:
            return
        self.built = True
        self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True)
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "layer_norm", None) is not None:
            with tf.name_scope(self.layer_norm.name):
                self.layer_norm.build([None, None, self.config.hidden_size])
        # Only build `decoder` as a layer in the untied case; when tied it is a raw weight.
        if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings:
            with tf.name_scope(self.decoder.name):
                self.decoder.build([None, None, self.config.hidden_size])

    def get_bias(self):
        """Return the LM head's output bias, keyed for HF weight (de)serialization."""
        return {"bias": self.bias}

    def call(self, features):
        """Transform the encoder output and project it to vocabulary logits."""
        x = self.dense(features)
        x = tf.nn.gelu(x)
        x = self.layer_norm(x)

        # project back to size of vocabulary with bias
        if self.config.tie_word_embeddings:
            # `decoder` holds the (vocab, hidden) embedding matrix; transpose for the projection.
            x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias
        else:
            x = self.decoder(x) + self.bias
        return x
+
+
@add_start_docstrings(
    """
    ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    ESM_START_DOCSTRING,
)
class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss):
    """ESM encoder (no pooler) with a `TFEsmClassificationHead` for sequence-level tasks."""

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        self.classifier = TFEsmClassificationHead(config, name="classifier")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # The classification head consumes the full sequence output.
        logits = self.classifier(encoder_outputs[0])

        if labels is None:
            loss = None
        else:
            loss = self.hf_compute_loss(labels, logits)

        if not return_dict:
            tuple_output = (logits,) + encoder_outputs[2:]
            if loss is None:
                return tuple_output
            return (loss,) + tuple_output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # Build the encoder and head under their own name scopes.
        for sublayer in (getattr(self, "esm", None), getattr(self, "classifier", None)):
            if sublayer is not None:
                with tf.name_scope(sublayer.name):
                    sublayer.build(None)
+
+
@add_start_docstrings(
    """
    ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    ESM_START_DOCSTRING,
)
class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss):
    # A checkpoint's pooler weights are irrelevant here; position_ids are generated on the fly.
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Per-token logits: dropout + a single dense projection over each hidden state.
        self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm")
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = keras.layers.Dense(config.num_labels, name="classifier")
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        labels: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = outputs[0]

        # Dropout only has an effect when training=True.
        sequence_output = self.dropout(sequence_output, training=training)
        logits = self.classifier(sequence_output)

        loss = None if labels is None else self.hf_compute_loss(labels, logits)

        if not return_dict:
            # Tuple layout: (loss?, logits, *extras from index 2 onward).
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        """Build sublayers; the classifier's input size is hidden_size (dropout has no weights to build)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "esm", None) is not None:
            with tf.name_scope(self.esm.name):
                self.esm.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
+
+
class TFEsmClassificationHead(keras.layers.Layer):
    """Head for sentence-level classification tasks.

    Pools the sequence by taking the first token's hidden state, then applies
    dropout -> tanh dense -> dropout -> linear projection to num_labels logits.
    """

    def __init__(self, config, name=None):
        super().__init__(name=name)
        self.dense = keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
        self.out_proj = keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="linear",
            name="out_proj",
        )
        self.config = config

    def call(self, features, training=False):
        x = features[:, 0, :]  # take first token (equiv. to [CLS])
        # Dropout is intentionally applied twice: before the dense projection and before out_proj.
        x = self.dropout(x, training=training)
        x = self.dense(x)
        x = self.dropout(x, training=training)
        x = self.out_proj(x)
        return x

    def build(self, input_shape=None):
        """Build both dense sublayers with hidden_size inputs (dropout has no weights)."""
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.config.hidden_size])
+
+
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: tf.Tensor of token ids, shape `(batch_size, sequence_length)`.
        padding_idx: id of the padding token; padding positions keep `padding_idx` as their position id.
        past_key_values_length: number of positions already consumed by a cached prefix
            (shifts all new positions forward by this amount).

    Returns:
        tf.Tensor of position ids with the same shape as `input_ids`.
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = tf.cast(input_ids != padding_idx, tf.int64)
    # cumsum yields 1-based positions over non-padding tokens; multiplying by the mask zeroes
    # out padding slots, and the final add shifts everything past padding_idx.
    incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask
    return incremental_indices + padding_idx
+
+
# Public API of this module: the TF ESM task heads and base classes exported to callers.
__all__ = [
    "TFEsmForMaskedLM",
    "TFEsmForSequenceClassification",
    "TFEsmForTokenClassification",
    "TFEsmModel",
    "TFEsmPreTrainedModel",
]
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a8c149ae320dd9b045edc5df31760a4eebefd9
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py
@@ -0,0 +1,8 @@
+from .chunk_utils import chunk_layer
+from .data_transforms import make_atom14_masks
+from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames
+from .loss import compute_predicted_aligned_error, compute_tm
+from .protein import Protein as OFProtein
+from .protein import to_pdb
+from .rigid_utils import Rigid, Rotation
+from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ff6b74d6c3f59ac23b68e2be0999b5f8dbb4ef
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py
@@ -0,0 +1,397 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import math
+from functools import partial
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+from .tensor_utils import tensor_tree_map, tree_map
+
+
+def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]:
+ shapes = []
+ if isinstance(tree, dict):
+ for v in tree.values():
+ shapes.extend(_fetch_dims(v))
+ elif isinstance(tree, (list, tuple)):
+ for t in tree:
+ shapes.extend(_fetch_dims(t))
+ elif isinstance(tree, torch.Tensor):
+ shapes.append(tree.shape)
+ else:
+ raise TypeError("Not supported")
+
+ return shapes
+
+
+@torch.jit.ignore
+def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]:
+ idx = []
+ for d in reversed(dims):
+ idx.append(flat_idx % d)
+ flat_idx = flat_idx // d
+
+ return tuple(reversed(idx))
+
+
@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: Sequence[int],
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> List[Tuple[slice, ...]]:
    """
    Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields
    tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of
    slices, and perhaps even the shortest possible (I'm pretty sure it's the latter).

    end is INCLUSIVE.
    """

    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(l: List[bool]) -> None:
        # An index is only "on the edge" if all trailing indices are too; sweep from the
        # last dimension backwards, AND-ing as we go (modifies l in place).
        tally = True
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] &= tally
            tally = l[reversed_idx]

    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if len(start) == 0:
        return [()]
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    slices: List[Tuple[slice, ...]] = []
    path_list: List[slice] = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if s == e:
            path_list.append(slice(s, s + 1))
        else:
            break

    path: Tuple[slice, ...] = tuple(path_list)
    divergence_idx = len(path)

    # start == end, and we're done
    if divergence_idx == len(dims):
        return [path]

    def upper() -> Tuple[Tuple[slice, ...], ...]:
        # Slices covering [start, end-of-subtree] below the divergence point.
        assert start_edges is not None
        assert end_edges is not None

        sdi = start[divergence_idx]
        return tuple(
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1 :],
                [d - 1 for d in dims[divergence_idx + 1 :]],
                dims[divergence_idx + 1 :],
                start_edges=start_edges[divergence_idx + 1 :],
                end_edges=[True for _ in end_edges[divergence_idx + 1 :]],
            )
        )

    def lower() -> Tuple[Tuple[slice, ...], ...]:
        # Slices covering [start-of-subtree, end] below the divergence point.
        assert start_edges is not None
        assert end_edges is not None

        edi = end[divergence_idx]
        return tuple(
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1 :]],
                end[divergence_idx + 1 :],
                dims[divergence_idx + 1 :],
                start_edges=[True for _ in start_edges[divergence_idx + 1 :]],
                end_edges=end_edges[divergence_idx + 1 :],
            )
        )

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif start_edges[divergence_idx]:
        slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),))
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),))
        slices.extend(lower())

    return slices
+
+
@torch.jit.ignore
def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor:
    """
    Equivalent to

    t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only
    reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk
    size.
    """

    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    # Each piece is flattened over its batch dims and the pieces are concatenated in order,
    # reproducing the contiguous [flat_start, flat_end) range.
    return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
+
+
def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
    _out: Any = None,
    _add_into_out: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples,
    and dicts with torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch
            dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined
            as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product
            of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly
            slower than the default setting.
        _out:
            (Private) Pre-allocated output pytree to write chunks into; allocated lazily when None.
        _add_into_out:
            (Private) When True, chunk results are accumulated (+=) into the output instead of assigned.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    # Broadcast-compatible batch shape: per-dimension max over all input leaves.
    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t: torch.Tensor) -> torch.Tensor:
        # Expand every leaf to the full batch shape; in the default mode, also flatten the
        # batch dims so chunks can be taken with a single leading-dim slice.
        if not low_mem:
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs)
    prepped_outputs = None
    if _out is not None:
        prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    # Ceiling division: last chunk may be smaller.
    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    def _select_chunk(t: torch.Tensor) -> torch.Tensor:
        # Leaves with a broadcast (size-1) leading dim are passed through whole.
        # NOTE: closes over the loop variable `i` below by reference.
        return t[i : i + chunk_size] if t.shape[0] != 1 else t

    i = 0
    out = prepped_outputs
    for _ in range(no_chunks):
        # Chunk the input
        if not low_mem:
            select_chunk = _select_chunk
        else:
            select_chunk = partial(
                _chunk_slice,
                flat_start=i,
                flat_end=min(flat_batch_dim, i + chunk_size),
                no_batch_dims=len(orig_batch_dims),
            )

        chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk)

        # Put the chunk in its pre-allocated space
        if isinstance(output_chunk, dict):

            def assign(d1: dict, d2: dict) -> None:
                # Recursively copy (or accumulate) each tensor leaf of the chunk output
                # into the corresponding flat slice of the full output.
                for k, v in d1.items():
                    if isinstance(v, dict):
                        assign(v, d2[k])
                    else:
                        if _add_into_out:
                            v[i : i + chunk_size] += d2[k]
                        else:
                            v[i : i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif isinstance(output_chunk, tuple):
            for x1, x2 in zip(out, output_chunk):
                if _add_into_out:
                    x1[i : i + chunk_size] += x2
                else:
                    x1[i : i + chunk_size] = x2
        elif isinstance(output_chunk, torch.Tensor):
            if _add_into_out:
                out[i : i + chunk_size] += output_chunk
            else:
                out[i : i + chunk_size] = output_chunk
        else:
            raise TypeError("Not supported")

        i += chunk_size

    # Restore the original (unflattened) batch shape on every output leaf.
    out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out)

    return out
+
+
class ChunkSizeTuner:
    """Searches for the largest chunk size that runs without error, caching the result.

    The tuned size is recomputed only when the (shape-level) argument signature changes
    between calls to `tune_chunk_size`.
    """

    def __init__(
        self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs I've run the model on.
        max_chunk_size: int = 512,
    ):
        self.max_chunk_size = max_chunk_size
        self.cached_chunk_size: Optional[int] = None
        self.cached_arg_data: Optional[tuple] = None

    def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int:
        """Binary-search the powers of two in (min_chunk_size, max_chunk_size] for the
        largest chunk size at which `fn(*args, chunk_size=...)` does not raise."""
        logging.info("Tuning chunk size...")

        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates
        # Nudge the largest candidate past max_chunk_size so the boundary itself gets probed.
        candidates[-1] += 4

        def test_chunk_size(chunk_size: int) -> bool:
            # A RuntimeError (typically OOM) means this chunk size is not viable.
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        return candidates[min_viable_chunk_size_index]

    def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool:
        """Return True iff the two (nested) argument caches are element-wise equal."""
        consistent = True
        for a1, a2 in zip(ac1, ac2):
            # Fixed: inspect the paired *elements* (a1/a2), not the outer containers
            # (ac1/ac2) — the original version dispatched the recursion on the wrong
            # object, missing per-element type mismatches and misrouting nested dicts.
            assert type(a1) is type(a2)
            if isinstance(a1, (list, tuple)):
                consistent &= self._compare_arg_caches(a1, a2)
            elif isinstance(a1, dict):
                # Compare values in a deterministic, key-sorted order.
                a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])]
                a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])]
                consistent &= self._compare_arg_caches(a1_items, a2_items)
            else:
                consistent &= a1 == a2

        return consistent

    def tune_chunk_size(
        self,
        representative_fn: Callable,
        args: tuple,
        min_chunk_size: int,
    ) -> int:
        """Return a chunk size for `representative_fn(*args)`, re-tuning only when the
        shape signature of `args` differs from the cached one."""
        consistent = True
        # Replace tensors by their shapes so the cache key ignores tensor contents.
        arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object)
        if self.cached_arg_data is not None:
            # If args have changed shape/value, we need to re-tune
            assert len(self.cached_arg_data) == len(arg_data)
            consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
        else:
            # Nothing cached yet: force a tuning pass.
            consistent = False

        if not consistent:
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        assert self.cached_chunk_size is not None

        return self.cached_chunk_size
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d4c17589ae66df2a8fd0ccfe8d6e335004eed9a
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py
@@ -0,0 +1,93 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import numpy as np
+import torch
+
+from . import residue_constants as rc
+from .tensor_utils import tensor_tree_map, tree_map
+
+
def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Construct denser atom positions (14 dimensions instead of 37).

    Adds to `protein`, indexed by the residue types in `protein["aatype"]`:
      - "atom14_atom_exists": per-residue mask over the 14 dense atom slots
      - "residx_atom14_to_atom37": atom37 index for each atom14 slot
      - "residx_atom37_to_atom14": atom14 index for each atom37 slot
      - "atom37_atom_exists": per-residue mask over the 37 sparse atom slots
    """
    # Per-restype lookup tables, built once from residue_constants.
    restype_atom14_to_atom37_list = []
    restype_atom37_to_atom14_list = []
    restype_atom14_mask_list = []

    for rt in rc.restypes:
        atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
        # Empty atom names map to index 0; the companion mask marks them absent.
        restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names])
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14_list.append(
            [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types]
        )

        restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names])

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37_list.append([0] * 14)
    restype_atom37_to_atom14_list.append([0] * 37)
    restype_atom14_mask_list.append([0.0] * 14)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37_list,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14_list,
        dtype=torch.int32,
        device=protein["aatype"].device,
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask_list,
        dtype=torch.float32,
        device=protein["aatype"].device,
    )
    # Long dtype required for tensor indexing below.
    protein_aatype = protein["aatype"].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
    residx_atom14_mask = restype_atom14_mask[protein_aatype]

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()

    # create the gather indices for mapping back
    residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
    protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()

    # create the corresponding mask
    restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device)
    for restype, restype_letter in enumerate(rc.restypes):
        restype_name = rc.restype_1to3[restype_letter]
        atom_names = rc.residue_atoms[restype_name]
        for atom_name in atom_names:
            atom_type = rc.atom_order[atom_name]
            restype_atom37_mask[restype, atom_type] = 1

    residx_atom37_mask = restype_atom37_mask[protein_aatype]
    protein["atom37_atom_exists"] = residx_atom37_mask

    return protein
+
+
def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
    """NumPy wrapper for `make_atom14_masks`: converts ndarray leaves to tensors, runs the
    mask construction, then converts every tensor leaf back to a NumPy array."""
    batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray)
    out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch))
    return out
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7b90dfe79b24b852cb26fca998bda831f36a6f
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py
@@ -0,0 +1,253 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Tuple, overload
+
+import torch
+import torch.types
+from torch import nn
+
+from . import residue_constants as rc
+from .rigid_utils import Rigid, Rotation
+from .tensor_utils import batched_gather
+
+
@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor: ...


@overload
def pseudo_beta_fn(
    aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Return pseudo-beta coordinates: the CB atom position, except CA for glycine (which has no CB).

    When `all_atom_masks` is provided, also returns the matching per-residue mask.
    """
    is_gly = aatype == rc.restype_order["G"]
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    # Broadcast the glycine flag over the trailing coordinate dimension (size 3)
    # so torch.where can select CA vs. CB positions per residue.
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
+
+
def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Scatter dense atom14 per-residue data into the sparse atom37 layout.

    Uses `batch["residx_atom37_to_atom14"]` as gather indices over the atom dimension (-2)
    and zeroes out atom37 slots that don't exist for the residue type.
    """
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )

    # Mask out slots absent in the residue's atom37 layout (broadcast over the last dim).
    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]

    return atom37_data
+
+
def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Assemble the per-residue template angle feature vector.

    Concatenates, along the last dimension: a 22-way one-hot of the template residue type,
    the flattened (7 angles x sin/cos = 14) torsion angles, the flattened alternate torsion
    angles, and the 7-dim torsion angle mask — 57 features in total.
    """
    aatype = template_feats["template_aatype"]
    sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
    angle_mask = template_feats["template_torsion_angles_mask"]

    # Flatten the trailing (angles, 2) pair into a single 14-wide feature block.
    flat_sin_cos = sin_cos.reshape(*sin_cos.shape[:-2], 14)
    flat_alt_sin_cos = alt_sin_cos.reshape(*alt_sin_cos.shape[:-2], 14)

    pieces = [
        nn.functional.one_hot(aatype, 22),
        flat_sin_cos,
        flat_alt_sin_cos,
        angle_mask,
    ]
    return torch.cat(pieces, dim=-1)
+
+
def build_template_pair_feat(
    batch: Dict[str, torch.Tensor],
    min_bin: torch.types.Number,
    max_bin: torch.types.Number,
    no_bins: int,
    use_unit_vector: bool = False,
    eps: float = 1e-20,
    inf: float = 1e8,
) -> torch.Tensor:
    """Assemble pairwise template features: a pseudo-beta distogram, pair masks, residue-type
    one-hots for both residues, and (optionally zeroed) inter-residue unit vectors."""
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    # Compute distogram (this seems to differ slightly from Alg. 5)
    tpb = batch["template_pseudo_beta"]
    # Squared pairwise distances, one-hot binned between min_bin and max_bin
    # (squared bin edges avoid a sqrt; the final bin is open-ended via `inf`).
    dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot: torch.LongTensor = nn.functional.one_hot(
        batch["template_aatype"],
        rc.restype_num + 2,
    )

    # Tile the residue-type one-hot along both pair axes (row- and column-residue identity).
    n_res = batch["template_aatype"].shape[-1]
    to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1))
    to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1))

    # Backbone frames built from N/CA/C positions.
    n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
    rigids = Rigid.make_transform_from_reference(
        n_xyz=batch["template_all_atom_positions"][..., n, :],
        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
        c_xyz=batch["template_all_atom_positions"][..., c, :],
        eps=eps,
    )
    points = rigids.get_trans()[..., None, :, :]
    # Displacement of every residue's CA expressed in every other residue's local frame.
    rigid_vec = rigids[..., None].invert_apply(points)

    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

    # Pair mask from backbone-atom availability (overwrites the pseudo-beta-based mask above).
    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar[..., None]

    # Zeroing (rather than omitting) the unit vectors keeps the feature layout fixed.
    if not use_unit_vector:
        unit_vector = unit_vector * 0.0

    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])

    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]

    return act
+
+
def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Concatenate the extra-MSA features along the last dimension.

    Combines a 23-way one-hot of the MSA tokens with the has-deletion flag and the
    deletion value, each expanded to a trailing singleton dim — 25 features per position.
    """
    one_hot_msa = nn.functional.one_hot(batch["extra_msa"], 23)
    has_deletion = batch["extra_has_deletion"].unsqueeze(-1)
    deletion_value = batch["extra_deletion_value"].unsqueeze(-1)
    return torch.cat([one_hot_msa, has_deletion, deletion_value], dim=-1)
+
+
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
) -> Rigid:
    """Compose per-residue torsion-angle rotations with the residue-type default frames and
    the backbone frames `r` to produce global frames for all 8 rigid groups per residue.

    `alpha` holds (sin, cos) pairs for 7 torsion angles; `rrgdf` is the per-restype table of
    default 4x4 group transforms, indexed by `aatype`.
    """
    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    # One [*, N, 8, 3, 3] rotation matrix and
    # One [*, N, 8, 3] translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    # Identity (sin=0, cos=1) angle prepended for the backbone group.
    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0 , 0 ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))

    # Chi angles 2-4 are defined relative to the preceding chi frame, so chain
    # their transforms back to the backbone frame.
    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    # Lift every group frame into the global frame via the backbone transform.
    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global
+
+
def frames_and_literature_positions_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
    default_frames: torch.Tensor,
    group_idx: torch.Tensor,
    atom_mask: torch.Tensor,
    lit_positions: torch.Tensor,
) -> torch.Tensor:
    """Place idealized (literature) atom positions into global coordinates.

    Each of the 14 dense atoms belongs to one of the 8 rigid groups (`group_idx`); the atom's
    literature position is transformed by that group's global frame from `r`, then masked by
    the residue's atom existence mask.
    """
    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8] — weighting frames by the one-hot selects each atom's group frame.
    t_atoms_to_global = r[..., None, :] * group_mask_one_hot

    # [*, N, 14] — summing over the group axis collapses to the single selected frame.
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c442786dc82ba2ebe243923509ed76a40de2a01
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py
@@ -0,0 +1,105 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor:
+ step = boundaries[1] - boundaries[0]
+ bin_centers = boundaries + step / 2
+ bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0)
+ return bin_centers
+
+
+def _calculate_expected_aligned_error(
+ alignment_confidence_breaks: torch.Tensor,
+ aligned_distance_error_probs: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
+ return (
+ torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
+ bin_centers[-1],
+ )
+
+
+def compute_predicted_aligned_error(
+ logits: torch.Tensor,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ **kwargs,
+) -> Dict[str, torch.Tensor]:
+ """Computes aligned confidence metrics from logits.
+
+ Args:
+ logits: [*, num_res, num_res, num_bins] the logits output from
+ PredictedAlignedErrorHead.
+ max_bin: Maximum bin value
+ no_bins: Number of bins
+ Returns:
+ aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
+ aligned error probabilities over bins for each residue pair.
+ predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
+ error for each pair of residues.
+ max_predicted_aligned_error: [*] the maximum predicted error possible.
+ """
+ boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
+
+ aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
+ predicted_aligned_error, max_predicted_aligned_error = _calculate_expected_aligned_error(
+ alignment_confidence_breaks=boundaries,
+ aligned_distance_error_probs=aligned_confidence_probs,
+ )
+
+ return {
+ "aligned_confidence_probs": aligned_confidence_probs,
+ "predicted_aligned_error": predicted_aligned_error,
+ "max_predicted_aligned_error": max_predicted_aligned_error,
+ }
+
+
+def compute_tm(
+ logits: torch.Tensor,
+ residue_weights: Optional[torch.Tensor] = None,
+ max_bin: int = 31,
+ no_bins: int = 64,
+ eps: float = 1e-8,
+ **kwargs,
+) -> torch.Tensor:
+ if residue_weights is None:
+ residue_weights = logits.new_ones(logits.shape[-2])
+
+ boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device)
+
+ bin_centers = _calculate_bin_centers(boundaries)
+ torch.sum(residue_weights)
+ n = logits.shape[-2]
+ clipped_n = max(n, 19)
+
+ d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2))
+ predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
+
+ normed_residue_mask = residue_weights / (eps + residue_weights.sum())
+ per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
+
+ weighted = per_alignment * residue_weights
+
+ argmax = (weighted == torch.max(weighted)).nonzero()[0]
+ return per_alignment[tuple(argmax)]
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae9d8c13277bd8c5e7dd152f50d549b6f7286af3
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py
@@ -0,0 +1,330 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Protein data type."""
+
+import dataclasses
+import re
+import string
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple
+
+import numpy as np
+
+from . import residue_constants
+
+
+FeatureDict = Mapping[str, np.ndarray]
+ModelOutput = Mapping[str, Any] # Is a nested dict.
+PICO_TO_ANGSTROM = 0.01
+
+
+@dataclasses.dataclass(frozen=True)
+class Protein:
+    """Protein structure representation.
+
+    All per-residue array fields share the leading ``num_res`` dimension.
+    Instances are frozen (immutable); use ``dataclasses.replace`` to derive
+    modified copies.
+    """
+
+    # Cartesian coordinates of atoms in angstroms. The atom types correspond to
+    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
+    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]
+
+    # Amino-acid type for each residue represented as an integer between 0 and
+    # 20, where 20 is 'X'.
+    aatype: np.ndarray  # [num_res]
+
+    # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
+    # is present and 0.0 if not. This should be used for loss masking.
+    atom_mask: np.ndarray  # [num_res, num_atom_type]
+
+    # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
+    residue_index: np.ndarray  # [num_res]
+
+    # B-factors, or temperature factors, of each residue (in sq. angstroms units),
+    # representing the displacement of the residue from its ground truth mean
+    # value.
+    b_factors: np.ndarray  # [num_res, num_atom_type]
+
+    # Chain indices for multi-chain predictions
+    chain_index: Optional[np.ndarray] = None
+
+    # Optional remark about the protein. Included as a comment in output PDB
+    # files
+    remark: Optional[str] = None
+
+    # Templates used to generate this protein (prediction-only)
+    parents: Optional[Sequence[str]] = None
+
+    # Chain corresponding to each parent
+    parents_chain_index: Optional[Sequence[int]] = None
+
+
+def from_proteinnet_string(proteinnet_str: str) -> Protein:
+ tag_re = r"(\[[A-Z]+\]\n)"
+ tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0]
+ groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]])
+
+ atoms: List[str] = ["N", "CA", "C"]
+ aatype = None
+ atom_positions = None
+ atom_mask = None
+ for g in groups:
+ if "[PRIMARY]" == g[0]:
+ seq = g[1][0].strip()
+ for i in range(len(seq)):
+ if seq[i] not in residue_constants.restypes:
+ seq[i] = "X" # FIXME: strings are immutable
+ aatype = np.array(
+ [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq]
+ )
+ elif "[TERTIARY]" == g[0]:
+ tertiary: List[List[float]] = []
+ for axis in range(3):
+ tertiary.append(list(map(float, g[1][axis].split())))
+ tertiary_np = np.array(tertiary)
+ atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32)
+ for i, atom in enumerate(atoms):
+ atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3])
+ atom_positions *= PICO_TO_ANGSTROM
+ elif "[MASK]" == g[0]:
+ mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip())))
+ atom_mask = np.zeros(
+ (
+ len(mask),
+ residue_constants.atom_type_num,
+ )
+ ).astype(np.float32)
+ for i, atom in enumerate(atoms):
+ atom_mask[:, residue_constants.atom_order[atom]] = 1
+ atom_mask *= mask[..., None]
+
+ assert aatype is not None
+
+ return Protein(
+ atom_positions=atom_positions,
+ atom_mask=atom_mask,
+ aatype=aatype,
+ residue_index=np.arange(len(aatype)),
+ b_factors=None,
+ )
+
+
+def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]:
+ pdb_headers: List[str] = []
+
+ remark = prot.remark
+ if remark is not None:
+ pdb_headers.append(f"REMARK {remark}")
+
+ parents = prot.parents
+ parents_chain_index = prot.parents_chain_index
+ if parents is not None and parents_chain_index is not None:
+ parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id]
+
+ if parents is None or len(parents) == 0:
+ parents = ["N/A"]
+
+ pdb_headers.append(f"PARENT {' '.join(parents)}")
+
+ return pdb_headers
+
+
+def add_pdb_headers(prot: Protein, pdb_str: str) -> str:
+ """Add pdb headers to an existing PDB string. Useful during multi-chain
+ recycling
+ """
+ out_pdb_lines: List[str] = []
+ lines = pdb_str.split("\n")
+
+ remark = prot.remark
+ if remark is not None:
+ out_pdb_lines.append(f"REMARK {remark}")
+
+ parents_per_chain: List[List[str]]
+ if prot.parents is not None and len(prot.parents) > 0:
+ parents_per_chain = []
+ if prot.parents_chain_index is not None:
+ parent_dict: Dict[str, List[str]] = {}
+ for p, i in zip(prot.parents, prot.parents_chain_index):
+ parent_dict.setdefault(str(i), [])
+ parent_dict[str(i)].append(p)
+
+ max_idx = max([int(chain_idx) for chain_idx in parent_dict])
+ for i in range(max_idx + 1):
+ chain_parents = parent_dict.get(str(i), ["N/A"])
+ parents_per_chain.append(chain_parents)
+ else:
+ parents_per_chain.append(list(prot.parents))
+ else:
+ parents_per_chain = [["N/A"]]
+
+ def make_parent_line(p: Sequence[str]) -> str:
+ return f"PARENT {' '.join(p)}"
+
+ out_pdb_lines.append(make_parent_line(parents_per_chain[0]))
+
+ chain_counter = 0
+ for i, l in enumerate(lines):
+ if "PARENT" not in l and "REMARK" not in l:
+ out_pdb_lines.append(l)
+ if "TER" in l and "END" not in lines[i + 1]:
+ chain_counter += 1
+ if not chain_counter >= len(parents_per_chain):
+ chain_parents = parents_per_chain[chain_counter]
+ else:
+ chain_parents = ["N/A"]
+
+ out_pdb_lines.append(make_parent_line(chain_parents))
+
+ return "\n".join(out_pdb_lines)
+
+
+def to_pdb(prot: Protein) -> str:
+    """Converts a `Protein` instance to a PDB string.
+
+    Args:
+        prot: The protein to convert to PDB.
+
+    Returns:
+        PDB string.
+    """
+    restypes = residue_constants.restypes + ["X"]
+
+    def res_1to3(r: int) -> str:
+        # Unknown one-letter codes fall back to the "UNK" residue name.
+        return residue_constants.restype_1to3.get(restypes[r], "UNK")
+
+    atom_types = residue_constants.atom_types
+
+    pdb_lines: List[str] = []
+
+    atom_mask = prot.atom_mask
+    aatype = prot.aatype
+    atom_positions = prot.atom_positions
+    residue_index = prot.residue_index.astype(np.int32)
+    b_factors = prot.b_factors
+    chain_index = prot.chain_index
+
+    if np.any(aatype > residue_constants.restype_num):
+        raise ValueError("Invalid aatypes.")
+
+    # REMARK/PARENT headers for the first chain go before any atom records.
+    headers = get_pdb_headers(prot)
+    if len(headers) > 0:
+        pdb_lines.extend(headers)
+
+    n = aatype.shape[0]
+    atom_index = 1
+    prev_chain_index = 0
+    chain_tags = string.ascii_uppercase
+    chain_tag = None
+    # Add all atom sites.
+    for i in range(n):
+        res_name_3 = res_1to3(aatype[i])
+        for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
+            # Skip atoms that are absent for this residue (mask is 0/1).
+            if mask < 0.5:
+                continue
+
+            record_type = "ATOM"
+            # Short atom names get a leading space so they land in the right
+            # column of the fixed-width PDB format.
+            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
+            alt_loc = ""
+            insertion_code = ""
+            occupancy = 1.00
+            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
+            charge = ""
+
+            chain_tag = "A"
+            if chain_index is not None:
+                chain_tag = chain_tags[chain_index[i]]
+
+            # PDB is a columnar format, every space matters here!
+            atom_line = (
+                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
+                f"{res_name_3:>3} {chain_tag:>1}"
+                f"{residue_index[i]:>4}{insertion_code:>1}   "
+                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
+                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
+                f"{element:>2}{charge:>2}"
+            )
+            pdb_lines.append(atom_line)
+            atom_index += 1
+
+        # Terminate at the end of the protein, or at a chain boundary when the
+        # next residue belongs to a different chain.
+        should_terminate = i == n - 1
+        if chain_index is not None:
+            if i != n - 1 and chain_index[i + 1] != prev_chain_index:
+                should_terminate = True
+                prev_chain_index = chain_index[i + 1]
+
+        if should_terminate:
+            # Close the chain.
+            chain_end = "TER"
+            chain_termination_line = (
+                f"{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}"
+            )
+            pdb_lines.append(chain_termination_line)
+            atom_index += 1
+
+            if i != n - 1:
+                # "prev" is a misnomer here. This happens at the beginning of
+                # each new chain.
+                pdb_lines.extend(get_pdb_headers(prot, prev_chain_index))
+
+    pdb_lines.append("END")
+    pdb_lines.append("")
+    return "\n".join(pdb_lines)
+
+
+def ideal_atom_mask(prot: Protein) -> np.ndarray:
+ """Computes an ideal atom mask.
+
+ `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function
+ computes a mask according to heavy atoms that should be present in the given sequence of amino acids.
+
+ Args:
+ prot: `Protein` whose fields are `numpy.ndarray` objects.
+
+ Returns:
+ An ideal atom mask.
+ """
+ return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
+
+
+def from_prediction(
+    features: FeatureDict,
+    result: ModelOutput,
+    b_factors: Optional[np.ndarray] = None,
+    chain_index: Optional[np.ndarray] = None,
+    remark: Optional[str] = None,
+    parents: Optional[Sequence[str]] = None,
+    parents_chain_index: Optional[Sequence[int]] = None,
+) -> Protein:
+    """Assembles a protein from a prediction.
+
+    Args:
+        features: Dictionary holding model inputs.
+        result: Dictionary holding model outputs.
+        b_factors: (Optional) B-factors to use for the protein.
+        chain_index: (Optional) Chain indices for multi-chain predictions
+        remark: (Optional) Remark about the prediction
+        parents: (Optional) List of template names
+        parents_chain_index: (Optional) Chain index corresponding to each parent
+    Returns:
+        A protein instance.
+    """
+    return Protein(
+        aatype=features["aatype"],
+        atom_positions=result["final_atom_positions"],
+        atom_mask=result["final_atom_mask"],
+        # Shift the model's residue index to PDB-style 1-based numbering.
+        residue_index=features["residue_index"] + 1,
+        # Default to zero B-factors (same shape as the atom mask) when none given.
+        b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]),
+        chain_index=chain_index,
+        remark=remark,
+        parents=parents,
+        parents_chain_index=parents_chain_index,
+    )
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..b05a603fb29f51e2b870c73fa28899cd91936bdb
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py
@@ -0,0 +1,981 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Constants used in AlphaFold."""
+
+import collections
+import copy
+import functools
+from importlib import resources
+from typing import Dict, List, Mapping, Sequence, Tuple
+
+import numpy as np
+
+
+# Internal import (35fd).
+
+
+# Distance from one CA to next CA [trans configuration: omega = 180].
+ca_ca = 3.80209737096
+
+# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
+# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
+# chi angles so their chi angle lists are empty.
+chi_angles_atoms: Dict[str, List[List[str]]] = {
+ "ALA": [],
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
+ "ARG": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "NE"], ["CG", "CD", "NE", "CZ"]],
+ "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
+ "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]],
+ "CYS": [["N", "CA", "CB", "SG"]],
+ "GLN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
+ "GLU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]],
+ "GLY": [],
+ "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]],
+ "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]],
+ "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "LYS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "CE"], ["CG", "CD", "CE", "NZ"]],
+ "MET": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "SD"], ["CB", "CG", "SD", "CE"]],
+ "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]],
+ "SER": [["N", "CA", "CB", "OG"]],
+ "THR": [["N", "CA", "CB", "OG1"]],
+ "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]],
+ "VAL": [["N", "CA", "CB", "CG1"]],
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The order is as per restype_order (see below).
+chi_angles_mask: List[List[float]] = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [1.0, 1.0, 1.0, 1.0], # ARG
+ [1.0, 1.0, 0.0, 0.0], # ASN
+ [1.0, 1.0, 0.0, 0.0], # ASP
+ [1.0, 0.0, 0.0, 0.0], # CYS
+ [1.0, 1.0, 1.0, 0.0], # GLN
+ [1.0, 1.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [1.0, 1.0, 0.0, 0.0], # HIS
+ [1.0, 1.0, 0.0, 0.0], # ILE
+ [1.0, 1.0, 0.0, 0.0], # LEU
+ [1.0, 1.0, 1.0, 1.0], # LYS
+ [1.0, 1.0, 1.0, 0.0], # MET
+ [1.0, 1.0, 0.0, 0.0], # PHE
+ [1.0, 1.0, 0.0, 0.0], # PRO
+ [1.0, 0.0, 0.0, 0.0], # SER
+ [1.0, 0.0, 0.0, 0.0], # THR
+ [1.0, 1.0, 0.0, 0.0], # TRP
+ [1.0, 1.0, 0.0, 0.0], # TYR
+ [1.0, 0.0, 0.0, 0.0], # VAL
+]
+
+# The following chi angles are pi periodic: they can be rotated by a multiple
+# of pi without affecting the structure.
+chi_pi_periodic: List[List[float]] = [
+ [0.0, 0.0, 0.0, 0.0], # ALA
+ [0.0, 0.0, 0.0, 0.0], # ARG
+ [0.0, 0.0, 0.0, 0.0], # ASN
+ [0.0, 1.0, 0.0, 0.0], # ASP
+ [0.0, 0.0, 0.0, 0.0], # CYS
+ [0.0, 0.0, 0.0, 0.0], # GLN
+ [0.0, 0.0, 1.0, 0.0], # GLU
+ [0.0, 0.0, 0.0, 0.0], # GLY
+ [0.0, 0.0, 0.0, 0.0], # HIS
+ [0.0, 0.0, 0.0, 0.0], # ILE
+ [0.0, 0.0, 0.0, 0.0], # LEU
+ [0.0, 0.0, 0.0, 0.0], # LYS
+ [0.0, 0.0, 0.0, 0.0], # MET
+ [0.0, 1.0, 0.0, 0.0], # PHE
+ [0.0, 0.0, 0.0, 0.0], # PRO
+ [0.0, 0.0, 0.0, 0.0], # SER
+ [0.0, 0.0, 0.0, 0.0], # THR
+ [0.0, 0.0, 0.0, 0.0], # TRP
+ [0.0, 1.0, 0.0, 0.0], # TYR
+ [0.0, 0.0, 0.0, 0.0], # VAL
+ [0.0, 0.0, 0.0, 0.0], # UNK
+]
+
+# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
+# psi and chi angles:
+# 0: 'backbone group',
+# 1: 'pre-omega-group', (empty)
+# 2: 'phi-group', (currently empty, because it defines only hydrogens)
+# 3: 'psi-group',
+# 4,5,6,7: 'chi1,2,3,4-group'
+# The atom positions are relative to the axis-end-atom of the corresponding
+# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
+# is defined such that the dihedral-angle-definiting atom (the last entry in
+# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
+# format: [atomname, group_idx, rel_position]
+rigid_group_atom_positions: Dict[str, List[Tuple[str, int, Tuple[float, float, float]]]] = {
+ "ALA": [
+ ("N", 0, (-0.525, 1.363, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, -0.000, -0.000)),
+ ("CB", 0, (-0.529, -0.774, -1.205)),
+ ("O", 3, (0.627, 1.062, 0.000)),
+ ],
+ "ARG": [
+ ("N", 0, (-0.524, 1.362, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, -0.000, -0.000)),
+ ("CB", 0, (-0.524, -0.778, -1.209)),
+ ("O", 3, (0.626, 1.062, 0.000)),
+ ("CG", 4, (0.616, 1.390, -0.000)),
+ ("CD", 5, (0.564, 1.414, 0.000)),
+ ("NE", 6, (0.539, 1.357, -0.000)),
+ ("NH1", 7, (0.206, 2.301, 0.000)),
+ ("NH2", 7, (2.078, 0.978, -0.000)),
+ ("CZ", 7, (0.758, 1.093, -0.000)),
+ ],
+ "ASN": [
+ ("N", 0, (-0.536, 1.357, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, -0.000, -0.000)),
+ ("CB", 0, (-0.531, -0.787, -1.200)),
+ ("O", 3, (0.625, 1.062, 0.000)),
+ ("CG", 4, (0.584, 1.399, 0.000)),
+ ("ND2", 5, (0.593, -1.188, 0.001)),
+ ("OD1", 5, (0.633, 1.059, 0.000)),
+ ],
+ "ASP": [
+ ("N", 0, (-0.525, 1.362, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.527, 0.000, -0.000)),
+ ("CB", 0, (-0.526, -0.778, -1.208)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ("CG", 4, (0.593, 1.398, -0.000)),
+ ("OD1", 5, (0.610, 1.091, 0.000)),
+ ("OD2", 5, (0.592, -1.101, -0.003)),
+ ],
+ "CYS": [
+ ("N", 0, (-0.522, 1.362, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.524, 0.000, 0.000)),
+ ("CB", 0, (-0.519, -0.773, -1.212)),
+ ("O", 3, (0.625, 1.062, -0.000)),
+ ("SG", 4, (0.728, 1.653, 0.000)),
+ ],
+ "GLN": [
+ ("N", 0, (-0.526, 1.361, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, 0.000, 0.000)),
+ ("CB", 0, (-0.525, -0.779, -1.207)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ("CG", 4, (0.615, 1.393, 0.000)),
+ ("CD", 5, (0.587, 1.399, -0.000)),
+ ("NE2", 6, (0.593, -1.189, -0.001)),
+ ("OE1", 6, (0.634, 1.060, 0.000)),
+ ],
+ "GLU": [
+ ("N", 0, (-0.528, 1.361, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, -0.000, -0.000)),
+ ("CB", 0, (-0.526, -0.781, -1.207)),
+ ("O", 3, (0.626, 1.062, 0.000)),
+ ("CG", 4, (0.615, 1.392, 0.000)),
+ ("CD", 5, (0.600, 1.397, 0.000)),
+ ("OE1", 6, (0.607, 1.095, -0.000)),
+ ("OE2", 6, (0.589, -1.104, -0.001)),
+ ],
+ "GLY": [
+ ("N", 0, (-0.572, 1.337, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.517, -0.000, -0.000)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ],
+ "HIS": [
+ ("N", 0, (-0.527, 1.360, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, 0.000, 0.000)),
+ ("CB", 0, (-0.525, -0.778, -1.208)),
+ ("O", 3, (0.625, 1.063, 0.000)),
+ ("CG", 4, (0.600, 1.370, -0.000)),
+ ("CD2", 5, (0.889, -1.021, 0.003)),
+ ("ND1", 5, (0.744, 1.160, -0.000)),
+ ("CE1", 5, (2.030, 0.851, 0.002)),
+ ("NE2", 5, (2.145, -0.466, 0.004)),
+ ],
+ "ILE": [
+ ("N", 0, (-0.493, 1.373, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.527, -0.000, -0.000)),
+ ("CB", 0, (-0.536, -0.793, -1.213)),
+ ("O", 3, (0.627, 1.062, -0.000)),
+ ("CG1", 4, (0.534, 1.437, -0.000)),
+ ("CG2", 4, (0.540, -0.785, -1.199)),
+ ("CD1", 5, (0.619, 1.391, 0.000)),
+ ],
+ "LEU": [
+ ("N", 0, (-0.520, 1.363, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, -0.000, -0.000)),
+ ("CB", 0, (-0.522, -0.773, -1.214)),
+ ("O", 3, (0.625, 1.063, -0.000)),
+ ("CG", 4, (0.678, 1.371, 0.000)),
+ ("CD1", 5, (0.530, 1.430, -0.000)),
+ ("CD2", 5, (0.535, -0.774, 1.200)),
+ ],
+ "LYS": [
+ ("N", 0, (-0.526, 1.362, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, 0.000, 0.000)),
+ ("CB", 0, (-0.524, -0.778, -1.208)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ("CG", 4, (0.619, 1.390, 0.000)),
+ ("CD", 5, (0.559, 1.417, 0.000)),
+ ("CE", 6, (0.560, 1.416, 0.000)),
+ ("NZ", 7, (0.554, 1.387, 0.000)),
+ ],
+ "MET": [
+ ("N", 0, (-0.521, 1.364, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, 0.000, 0.000)),
+ ("CB", 0, (-0.523, -0.776, -1.210)),
+ ("O", 3, (0.625, 1.062, -0.000)),
+ ("CG", 4, (0.613, 1.391, -0.000)),
+ ("SD", 5, (0.703, 1.695, 0.000)),
+ ("CE", 6, (0.320, 1.786, -0.000)),
+ ],
+ "PHE": [
+ ("N", 0, (-0.518, 1.363, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.524, 0.000, -0.000)),
+ ("CB", 0, (-0.525, -0.776, -1.212)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ("CG", 4, (0.607, 1.377, 0.000)),
+ ("CD1", 5, (0.709, 1.195, -0.000)),
+ ("CD2", 5, (0.706, -1.196, 0.000)),
+ ("CE1", 5, (2.102, 1.198, -0.000)),
+ ("CE2", 5, (2.098, -1.201, -0.000)),
+ ("CZ", 5, (2.794, -0.003, -0.001)),
+ ],
+ "PRO": [
+ ("N", 0, (-0.566, 1.351, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.527, -0.000, 0.000)),
+ ("CB", 0, (-0.546, -0.611, -1.293)),
+ ("O", 3, (0.621, 1.066, 0.000)),
+ ("CG", 4, (0.382, 1.445, 0.0)),
+ # ('CD', 5, (0.427, 1.440, 0.0)),
+ ("CD", 5, (0.477, 1.424, 0.0)), # manually made angle 2 degrees larger
+ ],
+ "SER": [
+ ("N", 0, (-0.529, 1.360, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, -0.000, -0.000)),
+ ("CB", 0, (-0.518, -0.777, -1.211)),
+ ("O", 3, (0.626, 1.062, -0.000)),
+ ("OG", 4, (0.503, 1.325, 0.000)),
+ ],
+ "THR": [
+ ("N", 0, (-0.517, 1.364, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.526, 0.000, -0.000)),
+ ("CB", 0, (-0.516, -0.793, -1.215)),
+ ("O", 3, (0.626, 1.062, 0.000)),
+ ("CG2", 4, (0.550, -0.718, -1.228)),
+ ("OG1", 4, (0.472, 1.353, 0.000)),
+ ],
+ "TRP": [
+ ("N", 0, (-0.521, 1.363, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.525, -0.000, 0.000)),
+ ("CB", 0, (-0.523, -0.776, -1.212)),
+ ("O", 3, (0.627, 1.062, 0.000)),
+ ("CG", 4, (0.609, 1.370, -0.000)),
+ ("CD1", 5, (0.824, 1.091, 0.000)),
+ ("CD2", 5, (0.854, -1.148, -0.005)),
+ ("CE2", 5, (2.186, -0.678, -0.007)),
+ ("CE3", 5, (0.622, -2.530, -0.007)),
+ ("NE1", 5, (2.140, 0.690, -0.004)),
+ ("CH2", 5, (3.028, -2.890, -0.013)),
+ ("CZ2", 5, (3.283, -1.543, -0.011)),
+ ("CZ3", 5, (1.715, -3.389, -0.011)),
+ ],
+ "TYR": [
+ ("N", 0, (-0.522, 1.362, 0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.524, -0.000, -0.000)),
+ ("CB", 0, (-0.522, -0.776, -1.213)),
+ ("O", 3, (0.627, 1.062, -0.000)),
+ ("CG", 4, (0.607, 1.382, -0.000)),
+ ("CD1", 5, (0.716, 1.195, -0.000)),
+ ("CD2", 5, (0.713, -1.194, -0.001)),
+ ("CE1", 5, (2.107, 1.200, -0.002)),
+ ("CE2", 5, (2.104, -1.201, -0.003)),
+ ("OH", 5, (4.168, -0.002, -0.005)),
+ ("CZ", 5, (2.791, -0.001, -0.003)),
+ ],
+ "VAL": [
+ ("N", 0, (-0.494, 1.373, -0.000)),
+ ("CA", 0, (0.000, 0.000, 0.000)),
+ ("C", 0, (1.527, -0.000, -0.000)),
+ ("CB", 0, (-0.533, -0.795, -1.213)),
+ ("O", 3, (0.627, 1.062, -0.000)),
+ ("CG1", 4, (0.540, 1.429, -0.000)),
+ ("CG2", 4, (0.533, -0.776, 1.203)),
+ ],
+}
+
+# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
+residue_atoms: Dict[str, List[str]] = {
+ "ALA": ["C", "CA", "CB", "N", "O"],
+ "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"],
+ "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"],
+ "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"],
+ "CYS": ["C", "CA", "CB", "N", "O", "SG"],
+ "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"],
+ "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"],
+ "GLY": ["C", "CA", "N", "O"],
+ "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"],
+ "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"],
+ "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"],
+ "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"],
+ "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"],
+ "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"],
+ "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"],
+ "SER": ["C", "CA", "CB", "N", "O", "OG"],
+ "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"],
+ "TRP": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "CZ2", "CZ3", "CH2", "N", "NE1", "O"],
+ "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"],
+ "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"],
+}
+
+# Naming swaps for ambiguous atom names.
+# Due to symmetries in the amino acids the naming of atoms is ambiguous in
+# 4 of the 20 amino acids.
+# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
+# in LEU, VAL and ARG can be resolved by using the 3d constellations of
+# the 'ambiguous' atoms and their neighbours)
+# TODO: ^ interpret this
+residue_atom_renaming_swaps: Dict[str, Dict[str, str]] = {
+ "ASP": {"OD1": "OD2"},
+ "GLU": {"OE1": "OE2"},
+ "PHE": {"CD1": "CD2", "CE1": "CE2"},
+ "TYR": {"CD1": "CD2", "CE1": "CE2"},
+}
+
+# Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
+van_der_waals_radius: Dict[str, float] = {
+ "C": 1.7,
+ "N": 1.55,
+ "O": 1.52,
+ "S": 1.8,
+}
+
+Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"])
+BondAngle = collections.namedtuple(
+ "BondAngle",
+ ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"],
+)
+
+
+def map_structure_with_atom_order(in_list: list, first_call: bool = True) -> list:
+ # Maps strings in a nested list structure to their corresponding index in atom_order
+ if first_call:
+ in_list = copy.deepcopy(in_list)
+ for i in range(len(in_list)):
+ if isinstance(in_list[i], list):
+ in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False)
+ elif isinstance(in_list[i], str):
+ in_list[i] = atom_order[in_list[i]]
+ else:
+ raise TypeError("Unexpected type when mapping nested lists!")
+ return in_list
+
+
+@functools.lru_cache(maxsize=None)
+def load_stereo_chemical_props() -> Tuple[
+    Mapping[str, List[Bond]],
+    Mapping[str, List[Bond]],
+    Mapping[str, List[BondAngle]],
+]:
+    """Load stereo_chemical_props.txt into a nice structure.
+
+    Load literature values for bond lengths and bond angles and translate bond angles into the length of the opposite
+    edge of the triangle ("residue_virtual_bonds").
+
+    Returns:
+        residue_bonds: dict that maps resname --> list of Bond tuples residue_virtual_bonds: dict that maps resname -->
+        list of Bond tuples residue_bond_angles: dict that maps resname --> list of BondAngle tuples
+    """
+    # TODO: this file should be downloaded in a setup script
+    # NOTE(review): reads from the "openfold.resources" package — assumes
+    # openfold is installed alongside; confirm for this build tree.
+    stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt")
+
+    lines_iter = iter(stereo_chemical_props.splitlines())
+    # Load bond lengths.
+    # Each line has four whitespace-separated fields: "A1-A2 RES length stddev";
+    # a lone "-" line terminates the section.
+    residue_bonds: Dict[str, List[Bond]] = {}
+    next(lines_iter)  # Skip header line.
+    for line in lines_iter:
+        if line.strip() == "-":
+            break
+        bond, resname, bond_length, stddev = line.split()
+        atom1, atom2 = bond.split("-")
+        if resname not in residue_bonds:
+            residue_bonds[resname] = []
+        residue_bonds[resname].append(Bond(atom1, atom2, float(bond_length), float(stddev)))
+    residue_bonds["UNK"] = []
+
+    # Load bond angles.
+    # Fields: "A1-A2-A3 RES angle_deg stddev_deg"; angles converted to radians.
+    residue_bond_angles: Dict[str, List[BondAngle]] = {}
+    next(lines_iter)  # Skip empty line.
+    next(lines_iter)  # Skip header line.
+    for line in lines_iter:
+        if line.strip() == "-":
+            break
+        bond, resname, angle_degree, stddev_degree = line.split()
+        atom1, atom2, atom3 = bond.split("-")
+        if resname not in residue_bond_angles:
+            residue_bond_angles[resname] = []
+        residue_bond_angles[resname].append(
+            BondAngle(
+                atom1,
+                atom2,
+                atom3,
+                float(angle_degree) / 180.0 * np.pi,
+                float(stddev_degree) / 180.0 * np.pi,
+            )
+        )
+    residue_bond_angles["UNK"] = []
+
+    def make_bond_key(atom1_name: str, atom2_name: str) -> str:
+        """Unique key to lookup bonds."""
+        # Sorted so A-B and B-A map to the same key.
+        return "-".join(sorted([atom1_name, atom2_name]))
+
+    # Translate bond angles into distances ("virtual bonds").
+    residue_virtual_bonds: Dict[str, List[Bond]] = {}
+    for resname, bond_angles in residue_bond_angles.items():
+        # Create a fast lookup dict for bond lengths.
+        bond_cache: Dict[str, Bond] = {}
+        for b in residue_bonds[resname]:
+            bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
+        residue_virtual_bonds[resname] = []
+        for ba in bond_angles:
+            bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
+            bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
+
+            # Compute distance between atom1 and atom3 using the law of cosines
+            # c^2 = a^2 + b^2 - 2ab*cos(gamma).
+            gamma = ba.angle_rad
+            length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma))
+
+            # Propagation of uncertainty assuming uncorrelated errors.
+            dl_outer = 0.5 / length
+            dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
+            dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
+            dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
+            stddev = np.sqrt(
+                (dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2
+            )
+            residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev))
+
+    return (residue_bonds, residue_virtual_bonds, residue_bond_angles)
+
+
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element). Values are in Angstroms.
between_res_bond_length_c_n: Tuple[float, float] = (1.329, 1.341)
between_res_bond_length_stddev_c_n: Tuple[float, float] = (0.014, 0.016)

# Between-residue cos_angles.
between_res_cos_angles_c_n_ca: Tuple[float, float] = (-0.5203, 0.0353)  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n: Tuple[float, float] = (-0.4473, 0.0311)  # degrees: 116.568 +- 1.995

# This mapping is used when we need to store atom data in a format that requires
# fixed atom data size for every residue (e.g. a numpy array).
atom_types: List[str] = [
    "N",
    "CA",
    "C",
    "CB",
    "O",
    "CG",
    "CG1",
    "CG2",
    "OG",
    "OG1",
    "SG",
    "CD",
    "CD1",
    "CD2",
    "ND1",
    "ND2",
    "OD1",
    "OD2",
    "SD",
    "CE",
    "CE1",
    "CE2",
    "CE3",
    "NE",
    "NE1",
    "NE2",
    "OE1",
    "OE2",
    "CH2",
    "NH1",
    "NH2",
    "OH",
    "CZ",
    "CZ2",
    "CZ3",
    "NZ",
    "OXT",
]
# Index of each atom type name within `atom_types`.
atom_order: Dict[str, int] = {atom_type: i for i, atom_type in enumerate(atom_types)}
atom_type_num = len(atom_types)  # := 37.
+
# A compact atom encoding with 14 columns
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
# Per-residue atom names in "atom14" order: the first four slots are always the
# backbone (N, CA, C, O); unused trailing slots are padded with "".
restype_name_to_atom14_names: Dict[str, List[str]] = {
    "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""],
    "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""],
    "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""],
    "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""],
    "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""],
    "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""],
    "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""],
    "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""],
    "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""],
    "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""],
    "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""],
    "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""],
    "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""],
    "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""],
    "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""],
    "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""],
    "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""],
    "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"],
    "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""],
    "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""],
    "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""],
}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
+
+
# This is the standard residue order when coding AA type as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes: List[str] = [
    "A",
    "R",
    "N",
    "D",
    "C",
    "Q",
    "E",
    "G",
    "H",
    "I",
    "L",
    "K",
    "M",
    "F",
    "P",
    "S",
    "T",
    "W",
    "Y",
    "V",
]
# One-letter code -> numeric residue type index.
restype_order: Dict[str, int] = {restype: i for i, restype in enumerate(restypes)}
restype_num = len(restypes)  # := 20.
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

# The 20 standard residues plus 'X' (unknown), in that order.
restypes_with_x: List[str] = restypes + ["X"]
restype_order_with_x: Dict[str, int] = {restype: i for i, restype in enumerate(restypes_with_x)}
+
+
def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray:
    """Maps the given sequence into a one-hot encoded matrix.

    Args:
        sequence: An amino acid sequence.
        mapping: A dictionary mapping amino acids to integers.
        map_unknown_to_x: If True, any amino acid that is not in the mapping will be
            mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown.
            If False, any amino acid not in the mapping will throw an error.

    Returns:
        A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence.

    Raises:
        ValueError: If the mapping doesn't contain values from 0 to
            num_unique_aas - 1 without any gaps.
    """
    num_entries = max(mapping.values()) + 1

    # The ids must form a gap-free enumeration 0..num_entries-1 so that each id
    # is a valid column index into the one-hot matrix.
    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s"
            % sorted(mapping.values())
        )

    def resolve_id(aa: str) -> int:
        # Translate a single residue letter to its column index.
        if not map_unknown_to_x:
            return mapping[aa]
        if aa.isalpha() and aa.isupper():
            return mapping.get(aa, mapping["X"])
        raise ValueError(f"Invalid character in the sequence: {aa}")

    encoded = np.zeros((len(sequence), num_entries), dtype=np.int32)
    for row, aa in enumerate(sequence):
        encoded[row, resolve_id(aa)] = 1

    return encoded
+
+
# One-letter -> three-letter residue name for the 20 standard amino acids.
restype_1to3: Dict[str, str] = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
}


# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
# many more, and less common, three letter names as keys and maps many of these
# to the same one letter name (including 'X' and 'U' which we don't use here).
restype_3to1: Dict[str, str] = {v: k for k, v in restype_1to3.items()}

# Define a restype name for all unknown residues.
unk_restype = "UNK"

# Three-letter names in standard order, with "UNK" appended as index 20.
resnames: List[str] = [restype_1to3[r] for r in restypes] + [unk_restype]
resname_to_idx: Dict[str, int] = {resname: i for i, resname in enumerate(resnames)}
+
+
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
# remaining 20 amino acids are kept in alphabetical order.
# There are 2 non-amino acid codes, X (representing any amino acid) and
# "-" representing a missing amino acid in an alignment. The id for these
# codes is put at the end (20 and 21) so that they can easily be ignored if
# desired.
HHBLITS_AA_TO_ID: Dict[str, int] = {
    "A": 0,
    "B": 2,
    "C": 1,
    "D": 2,
    "E": 3,
    "F": 4,
    "G": 5,
    "H": 6,
    "I": 7,
    "J": 20,
    "K": 8,
    "L": 9,
    "M": 10,
    "N": 11,
    "O": 20,
    "P": 12,
    "Q": 13,
    "R": 14,
    "S": 15,
    "T": 16,
    "U": 1,
    "V": 17,
    "W": 18,
    "X": 20,
    "Y": 19,
    "Z": 3,
    "-": 21,
}

# Partial inversion of HHBLITS_AA_TO_ID. "Partial" because several letters
# share an id above; each id maps back to a single canonical letter here.
ID_TO_HHBLITS_AA: Dict[int, str] = {
    0: "A",
    1: "C",  # Also U.
    2: "D",  # Also B.
    3: "E",  # Also Z.
    4: "F",
    5: "G",
    6: "H",
    7: "I",
    8: "K",
    9: "L",
    10: "M",
    11: "N",
    12: "P",
    13: "Q",
    14: "R",
    15: "S",
    16: "T",
    17: "V",
    18: "W",
    19: "Y",
    20: "X",  # Includes J and O.
    21: "-",
}

restypes_with_x_and_gap: List[str] = restypes + ["X", "-"]
# For each hhblits id, the index of its canonical letter in our own ordering.
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE: Tuple[int, ...] = tuple(
    restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap))
)
+
+
def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array."""
    # Row `restype_num` (the catch-all UNK residue) is left all-zero, hence +1.
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32)
    for res_idx, one_letter in enumerate(restypes):
        # Mark every atom37 slot that this residue type actually contains.
        for atom_name in residue_atoms[restype_1to3[one_letter]]:
            mask[res_idx, atom_order[atom_name]] = 1
    return mask


STANDARD_ATOM_MASK = _make_standard_atom_mask()
+
+
# A one hot representation for the first and second atoms defining the axis
# of rotation for each chi-angle in each residue.
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    Returns an array of shape [restype_num + 1, atom_type_num, 4]: for each
    residue type and each of the (up to 4) chi angles, a one-hot column over
    atom types selecting the atom at position `atom_index` of that chi group.
    """
    chi_angles_index = {}
    one_hots = []

    for k, v in chi_angles_atoms.items():
        indices = [atom_types.index(s[atom_index]) for s in v]
        # Pad residues with fewer than 4 chi angles. NOTE(review): the -1 pad
        # indexes the LAST row of np.eye below (atom "OXT"), not an all-zero
        # row — presumably intentional upstream behavior; confirm before
        # relying on the padded columns.
        indices.extend([-1] * (4 - len(indices)))
        chi_angles_index[k] = indices

    for r in restypes:
        res3 = restype_1to3[r]
        # Rows of the identity matrix act as one-hot vectors over atom types.
        one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
        one_hots.append(one_hot)

    one_hots.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.
    one_hot = np.stack(one_hots, axis=0)
    # Reorder to [residue, atom_type, chi_index].
    one_hot = np.transpose(one_hot, [0, 2, 1])

    return one_hot


chi_atom_1_one_hot = chi_angle_atom(1)
chi_atom_2_one_hot = chi_angle_atom(2)
+
# An array like chi_angles_atoms but using indices rather than names.
chi_angles_atom_indices_list: List[List[List[str]]] = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
# NOTE(review): map_structure_with_atom_order appears to convert the nested
# name lists to atom indices IN PLACE (the same list is consumed as numeric
# data in the np.array below) — confirm against its definition.
chi_angles_atom_indices_ours: list = map_structure_with_atom_order(chi_angles_atom_indices_list)
# Pad each residue to exactly 4 chi groups of 4 atom indices each.
chi_angles_atom_indices = np.array(
    [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices_list]
)

# Mapping from (res_name, atom_name) pairs to the atom's chi group index
# and atom index within that group.
chi_groups_for_atom: Dict[Tuple[str, str], List[Tuple[int, int]]] = collections.defaultdict(list)
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
    for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
        for atom_i, atom in enumerate(chi_group):
            chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
# Freeze to a plain dict so missing keys raise instead of silently inserting.
chi_groups_for_atom = dict(chi_groups_for_atom)
+
+
+def _make_rigid_transformation_4x4(ex: np.ndarray, ey: np.ndarray, translation: np.ndarray) -> np.ndarray:
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
+ # Normalize ex.
+ ex_normalized = ex / np.linalg.norm(ex)
+
+ # make ey perpendicular to ex
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
+ ey_normalized /= np.linalg.norm(ey_normalized)
+
+ # compute ez as cross product
+ eznorm = np.cross(ex_normalized, ey_normalized)
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
+ m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
+ return m
+
+
# create an array with (restype, atomtype) --> rigid_group_idx
# and an array with (restype, atomtype, coord) for the atom positions
# and compute affine transformation matrices (4,4) from one rigid group to the
# previous group
# Dimensions: 21 = 20 standard residues + UNK; 37 = atom37 types; 14 = atom14
# slots; 8 = rigid groups (backbone, pre-omega, phi, psi, chi1..chi4).
# All arrays are filled in by _make_rigid_group_constants() below.
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
+
+
def _make_rigid_group_constants() -> None:
    """Fill the arrays above.

    Populates, for every standard residue type, the atom37/atom14 rigid-group
    assignments, masks and idealized positions, plus the default 4x4 frames
    from each rigid group to its parent group. Mutates module-level arrays;
    returns nothing.
    """
    # First pass: per-atom group index, mask and idealized position in both
    # the atom37 and atom14 encodings.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
            atomtype = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atomtype] = group_idx
            restype_atom37_mask[restype, atomtype] = 1
            restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position

            atom14idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
            restype_atom14_mask[restype, atom14idx] = 1
            restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position

    # Second pass: default frames of each rigid group relative to its parent.
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions: Dict[str, np.ndarray] = {
            name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname]
        }

        # backbone to backbone is the identity transform
        restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)

        # phi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["N"] - atom_positions["CA"],
            ey=np.array([1.0, 0.0, 0.0]),
            translation=atom_positions["N"],
        )
        restype_rigid_group_default_frame[restype, 2, :, :] = mat

        # psi-frame to backbone
        mat = _make_rigid_transformation_4x4(
            ex=atom_positions["C"] - atom_positions["CA"],
            ey=atom_positions["CA"] - atom_positions["N"],
            translation=atom_positions["C"],
        )
        restype_rigid_group_default_frame[restype, 3, :, :] = mat

        # chi1-frame to backbone
        if chi_angles_mask[restype][0]:
            base_atom_names = chi_angles_atoms[resname][0]
            base_atom_positions = [atom_positions[name] for name in base_atom_names]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            restype_rigid_group_default_frame[restype, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in range(1, 4):
            if chi_angles_mask[restype][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat


_make_rigid_group_constants()
+
+
def make_atom14_dists_bounds(
    overlap_tolerance: float = 1.5,
    bond_length_tolerance_factor: int = 15,
) -> Dict[str, np.ndarray]:
    """compute upper and lower bounds for bonds to assess violations.

    Args:
        overlap_tolerance: Angstroms subtracted from the sum of van der Waals
            radii when computing the clash (lower) bound for non-bonded pairs.
        bond_length_tolerance_factor: number of standard deviations around the
            literature bond length allowed for bonded (and angle-derived
            "virtual") atom pairs.

    Returns:
        Dict with "lower_bound", "upper_bound" and "stddev", each a
        (21, 14, 14) array over residue type and atom14 index pairs.
    """
    restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
    restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]

        # create lower and upper bounds for clashes
        for atom1_idx, atom1_name in enumerate(atom_list):
            if not atom1_name:
                continue
            # First character of the atom name is its element (C, N, O, S).
            atom1_radius = van_der_waals_radius[atom1_name[0]]
            for atom2_idx, atom2_name in enumerate(atom_list):
                if (not atom2_name) or atom1_idx == atom2_idx:
                    continue
                atom2_radius = van_der_waals_radius[atom2_name[0]]
                lower = atom1_radius + atom2_radius - overlap_tolerance
                upper = 1e10
                # Fill both (i, j) and (j, i) to keep the matrices symmetric.
                restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
                restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
                restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
                restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper

        # overwrite lower and upper bounds for bonds and angles
        for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
            atom1_idx = atom_list.index(b.atom1_name)
            atom2_idx = atom_list.index(b.atom2_name)
            lower = b.length - bond_length_tolerance_factor * b.stddev
            upper = b.length + bond_length_tolerance_factor * b.stddev
            restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
            restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
            restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
            restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
            restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
            restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
    return {
        "lower_bound": restype_atom14_bond_lower_bound,  # shape (21,14,14)
        "upper_bound": restype_atom14_bond_upper_bound,  # shape (21,14,14)
        "stddev": restype_atom14_bond_stddev,  # shape (21,14,14)
    }
+
+
# 1.0 where an atom14 slot is symmetry-ambiguous (e.g. ASP OD1/OD2); and for
# each slot, the index it swaps with (identity for unambiguous atoms).
restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32)
restype_atom14_ambiguous_atoms_swap_idx: np.ndarray = np.tile(np.arange(14, dtype=int), (21, 1))


def _make_atom14_ambiguity_feats() -> None:
    """Fill the two ambiguity arrays above from residue_atom_renaming_swaps."""
    for res, pairs in residue_atom_renaming_swaps.items():
        res_idx = restype_order[restype_3to1[res]]
        for atom1, atom2 in pairs.items():
            atom1_idx = restype_name_to_atom14_names[res].index(atom1)
            atom2_idx = restype_name_to_atom14_names[res].index(atom2)
            # Mark both partners as ambiguous and record the swap in both
            # directions.
            restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1
            restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx
            restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx


_make_atom14_ambiguity_feats()
+
+
def aatype_to_str_sequence(aatype: Sequence[int]) -> str:
    """Decode numeric residue indices into a one-letter amino-acid string."""
    # Each index selects a letter from the 21-letter alphabet (20 AAs + 'X').
    return "".join(restypes_with_x[idx] for idx in aatype)
diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe72e4905b81faee5fe44131f1ba1b75856bd1c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py
@@ -0,0 +1,140 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload
+
+import torch
+import torch.nn as nn
+import torch.types
+
+
def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
    """Return `m1 + m2`, mutating `m1` when `inplace` is True.

    The first operation in a checkpoint can't be in-place, but it's nice to
    have in-place addition during inference — hence the flag.
    """
    if inplace:
        m1 += m2
        return m1
    return m1 + m2
+
+
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    """Permute the last `len(inds)` dimensions of `tensor` according to `inds`.

    `inds` is expressed relative to the trailing dimensions (e.g. [1, 0]
    swaps the final two axes); leading batch dimensions are untouched.
    """
    n_final = len(inds)
    batch_dims = list(range(tensor.dim() - n_final))
    return tensor.permute(batch_dims + [i - n_final for i in inds])
+
+
def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor:
    """Collapse the last `no_dims` dimensions of `t` into a single dimension."""
    flattened_shape = t.shape[:-no_dims] + (-1,)
    return t.reshape(flattened_shape)
+
+
def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor:
    """Mean of `value` over `dim`, weighted by `mask`.

    `eps` keeps the denominator nonzero when the mask is empty along `dim`.
    """
    broadcast_mask = mask.expand(*value.shape)
    masked_sum = torch.sum(broadcast_mask * value, dim=dim)
    return masked_sum / (eps + torch.sum(broadcast_mask, dim=dim))
+
+
def pts_to_distogram(
    pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64
) -> torch.Tensor:
    """Bucket pairwise Euclidean distances between points into distance bins.

    `no_bins - 1` boundaries partition the line into `no_bins` buckets.
    """
    boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    deltas = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    pairwise_dists = torch.sqrt((deltas**2).sum(dim=-1))
    return torch.bucketize(pairwise_dists, boundaries)
+
+
def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict:
    """Apply `fn` across the per-key values of `dicts`, recursing into sub-dicts.

    The first dict determines which keys are visited; every dict in `dicts`
    must contain those keys.
    """
    template = dicts[0]
    result = {}
    for key, sample_value in template.items():
        gathered = [d[key] for d in dicts]
        result[key] = dict_multimap(fn, gathered) if isinstance(sample_value, dict) else fn(gathered)
    return result
+
+
def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
    """One-hot encode each element of `x` by its nearest value in `v_bins`."""
    # Shape the bins so they broadcast against x's trailing new axis.
    bins = v_bins.view((1,) * x.dim() + (len(v_bins),))
    nearest_bin = torch.argmin(torch.abs(x[..., None] - bins), dim=-1)
    return nn.functional.one_hot(nearest_bin, num_classes=len(v_bins)).float()
+
+
def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor:
    """Gather `inds` along `dim` of `data`, with the first `no_batch_dims`
    dimensions treated as batch dimensions that align with `inds`."""
    # Per-batch-axis index tensors shaped to broadcast against `inds`.
    index_parts: List[Union[slice, torch.Tensor]] = []
    for axis, size in enumerate(data.shape[:no_batch_dims]):
        view_shape = (1,) * axis + (-1,) + (1,) * (inds.dim() - axis - 1)
        index_parts.append(torch.arange(size).view(view_shape))

    trailing: List[Union[slice, torch.Tensor]] = [slice(None)] * (data.dim() - no_batch_dims)
    trailing[dim - no_batch_dims if dim >= 0 else dim] = inds
    index_parts.extend(trailing)
    # Matt note: Editing this to get around the behaviour of using a list as an array index changing
    # in recent Numpy versions
    return data[tuple(index_parts)]
+
+
T = TypeVar("T")


# With tree_map, a poor man's JAX tree_map
def dict_map(
    fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T]
) -> Dict[Any, Union[dict, list, tuple, Any]]:
    """Apply `fn` to every `leaf_type` leaf of the (possibly nested) dict `dic`.

    Sub-dicts are recursed directly; every other value is delegated to
    `tree_map`, which handles lists, tuples and leaves.
    """
    new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {}
    for k, v in dic.items():
        if isinstance(v, dict):
            new_dict[k] = dict_map(fn, v, leaf_type)
        else:
            new_dict[k] = tree_map(fn, v, leaf_type)

    return new_dict


@overload
def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list: ...


@overload
def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple: ...


def tree_map(fn, tree, leaf_type):
    """Recursively apply `fn` to every `leaf_type` leaf of a nested
    dict/list/tuple structure, rebuilding the containers.

    Raises:
        TypeError: if `tree` is neither a supported container nor an instance
            of `leaf_type`.
    """
    if isinstance(tree, dict):
        return dict_map(fn, tree, leaf_type)
    elif isinstance(tree, list):
        return [tree_map(fn, x, leaf_type) for x in tree]
    elif isinstance(tree, tuple):
        return tuple(tree_map(fn, x, leaf_type) for x in tree)
    elif isinstance(tree, leaf_type):
        return fn(tree)
    else:
        # Report the offending type in the exception itself rather than
        # printing it to stdout (the bare print was a debugging leftover).
        raise TypeError(f"Not supported: {type(tree)}")


# Common specialization: torch.Tensor leaves.
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
diff --git a/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py b/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bc433e350e13c33fc779092baf52197d8aa5e0d
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py
@@ -0,0 +1,147 @@
+# coding=utf-8
+# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for ESM."""
+
+import os
+from typing import List, Optional
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
def load_vocab_file(vocab_file):
    """Read a vocabulary file (one token per line), stripping whitespace."""
    with open(vocab_file, "r") as f:
        return [token.strip() for token in f.read().splitlines()]
+
+
class EsmTokenizer(PreTrainedTokenizer):
    """
    Constructs an ESM tokenizer.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file (one token per line).
        unk_token, cls_token, pad_token, mask_token, eos_token (`str`, *optional*):
            The special tokens. NOTE(review): the previous defaults were empty
            strings — the angle-bracketed literals were evidently stripped from
            the source (HTML-tag-like text lost in extraction); restored to the
            standard ESM values `"<unk>"`, `"<cls>"`, `"<pad>"`, `"<mask>"`,
            `"<eos>"`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        **kwargs,
    ):
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
        super().__init__(
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            **kwargs,
        )

        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
        # none of them are special, but they all need special splitting.

        self.unique_no_split_tokens = self.all_tokens
        self._update_trie(self.unique_no_split_tokens)

    def _convert_id_to_token(self, index: int) -> str:
        # Unknown ids fall back to the unk token string.
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown tokens fall back to the unk token id.
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def _tokenize(self, text, **kwargs):
        # ESM sequences are whitespace-separated residue tokens.
        return text.split()

    def get_vocab(self):
        """Return the full vocabulary (base vocab plus added tokens)."""
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    def token_to_id(self, token: str) -> int:
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def id_to_token(self, index: int) -> str:
        return self._id_to_token.get(index, self.unk_token)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap one or two sequences with CLS/EOS for the model."""
        cls = [self.cls_token_id]
        sep = [self.eos_token_id]  # No sep token in ESM vocabulary
        if token_ids_1 is None:
            if self.eos_token_id is None:
                return cls + token_ids_0
            else:
                return cls + token_ids_0 + sep
        elif self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        return cls + token_ids_0 + sep + token_ids_1 + sep  # Multiple inputs always have an EOS token

    def get_special_tokens_mask(
        self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of ids of the second sequence.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )

            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
        # Mirrors build_inputs_with_special_tokens: CLS ... EOS [... EOS].
        mask = [1] + ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None):
        """Write the vocabulary, one token per line, and return the file path."""
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        with open(vocab_file, "w") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)
+
+
+__all__ = ["EsmTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/falcon/__init__.py b/docs/transformers/build/lib/transformers/models/falcon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9789767f11402264660b5dec0b5cae2466ee9d8
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/falcon/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
if TYPE_CHECKING:
    # For static type checkers, re-export the submodules' public names
    # directly so they resolve without the lazy machinery.
    from .configuration_falcon import *
    from .modeling_falcon import *
else:
    import sys

    # At runtime, replace this module object with a lazy proxy that imports
    # each submodule only when one of its attributes is first accessed.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)