diff --git a/.gitattributes b/.gitattributes index f8d2f31cf3853000a83449a2a5c3ac5bdc2b219f..375c7e662659a9b787545e4aeb0ec894bba5e3ed 100644 --- a/.gitattributes +++ b/.gitattributes @@ -53,3 +53,5 @@ docs/resources/grpo_geoqa.png filter=lfs diff=lfs merge=lfs -text docs/resources/grpo_openr1_multimodal.png filter=lfs diff=lfs merge=lfs -text docs/resources/web-ui-en.jpg filter=lfs diff=lfs merge=lfs -text docs/resources/kto_data.png filter=lfs diff=lfs merge=lfs -text +docs/resources/grpo_countdown_1.png filter=lfs diff=lfs merge=lfs -text +docs/resources/grpo_clevr_count.png filter=lfs diff=lfs merge=lfs -text diff --git a/docs/resources/grpo_clevr_count.png b/docs/resources/grpo_clevr_count.png new file mode 100644 index 0000000000000000000000000000000000000000..2c814e28e31908a7250220fde18b5bd024d82b35 --- /dev/null +++ b/docs/resources/grpo_clevr_count.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7192dc4f04801dbdff30bed098a16a7e21212a773ba7b6dc1424b261feca366f +size 671176 diff --git a/docs/resources/grpo_countdown_1.png b/docs/resources/grpo_countdown_1.png new file mode 100644 index 0000000000000000000000000000000000000000..819ab3d992619b077d75e6946d4637b030b8d213 --- /dev/null +++ b/docs/resources/grpo_countdown_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b78dc3ce1cd541e76f2c557dea3aff06b278bb3b5413946a92c584cf42c1369f +size 785044 diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2dca8bcd01c853025ec4364c10b4250421cf2d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/configuration_efficientformer.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EfficientFormer model configuration""" + +from typing import List + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class EfficientFormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`EfficientFormerModel`]. It is used to + instantiate an EfficientFormer model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the EfficientFormer + [snap-research/efficientformer-l1](https://huggingface.co/snap-research/efficientformer-l1) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + depths (`List(int)`, *optional*, defaults to `[3, 2, 6, 4]`) + Depth of each stage. + hidden_sizes (`List(int)`, *optional*, defaults to `[48, 96, 224, 448]`) + Dimensionality of each stage. + downsamples (`List(bool)`, *optional*, defaults to `[True, True, True, True]`) + Whether or not to downsample inputs between two stages. + dim (`int`, *optional*, defaults to 448): + Number of channels in Meta3D layers + key_dim (`int`, *optional*, defaults to 32): + The size of the key in meta3D block. + attention_ratio (`int`, *optional*, defaults to 4): + Ratio of the dimension of the query and value to the dimension of the key in MSHA block + resolution (`int`, *optional*, defaults to 7) + Size of each patch + num_hidden_layers (`int`, *optional*, defaults to 5): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the 3D MetaBlock. + mlp_expansion_ratio (`int`, *optional*, defaults to 4): + Ratio of size of the hidden dimensionality of an MLP to the dimensionality of its input. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings and encoder. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + pool_size (`int`, *optional*, defaults to 3): + Kernel size of pooling layers. + downsample_patch_size (`int`, *optional*, defaults to 3): + The size of patches in downsampling layers. + downsample_stride (`int`, *optional*, defaults to 2): + The stride of convolution kernels in downsampling layers. + downsample_pad (`int`, *optional*, defaults to 1): + Padding in downsampling layers. + drop_path_rate (`int`, *optional*, defaults to 0): + Rate at which to increase dropout probability in DropPath. + num_meta3d_blocks (`int`, *optional*, defaults to 1): + The number of 3D MetaBlocks in the last stage. + distillation (`bool`, *optional*, defaults to `True`): + Whether to add a distillation head. + use_layer_scale (`bool`, *optional*, defaults to `True`): + Whether to scale outputs from token mixers. + layer_scale_init_value (`float`, *optional*, defaults to 1e-5): + Factor by which outputs from token mixers are scaled. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to `224`): + The size (resolution) of each image. + + Example: + + ```python + >>> from transformers import EfficientFormerConfig, EfficientFormerModel + + >>> # Initializing a EfficientFormer efficientformer-l1 style configuration + >>> configuration = EfficientFormerConfig() + + >>> # Initializing a EfficientFormerModel (with random weights) from the efficientformer-l3 style configuration + >>> model = EfficientFormerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "efficientformer" + + def __init__( + self, + depths: List[int] = [3, 2, 6, 4], + hidden_sizes: List[int] = [48, 96, 224, 448], + downsamples: List[bool] = [True, True, True, True], + dim: int = 448, + key_dim: int = 32, + attention_ratio: int = 4, + resolution: int = 7, + num_hidden_layers: int = 5, + num_attention_heads: int = 8, + mlp_expansion_ratio: int = 4, + hidden_dropout_prob: float = 0.0, + patch_size: int = 16, + num_channels: int = 3, + pool_size: int = 3, + downsample_patch_size: int = 3, + downsample_stride: int = 2, + downsample_pad: int = 1, + drop_path_rate: float = 0.0, + num_meta3d_blocks: int = 1, + distillation: bool = True, + use_layer_scale: bool = True, + layer_scale_init_value: float = 1e-5, + hidden_act: str = "gelu", + initializer_range: float = 0.02, + layer_norm_eps: float = 1e-12, + image_size: int = 224, + batch_norm_eps: float = 1e-05, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.hidden_sizes = hidden_sizes + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.patch_size = patch_size + self.num_channels = num_channels + self.depths = depths + self.mlp_expansion_ratio = mlp_expansion_ratio + self.downsamples = downsamples + self.dim = dim + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.resolution = resolution + self.pool_size = pool_size + self.downsample_patch_size = downsample_patch_size + self.downsample_stride = downsample_stride + self.downsample_pad = downsample_pad + self.drop_path_rate = drop_path_rate + self.num_meta3d_blocks = num_meta3d_blocks + self.distillation = distillation + self.use_layer_scale = use_layer_scale + self.layer_scale_init_value = layer_scale_init_value + self.image_size = image_size + self.batch_norm_eps = batch_norm_eps + + +__all__ = [ + "EfficientFormerConfig", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..74d16a048de1e59a6d6e0f7fe001bb7194273abd --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/image_processing_efficientformer.py @@ -0,0 +1,324 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for EfficientFormer.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ....image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ....image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_batched, + is_scaled_image, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ....utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class EfficientFormerImageProcessor(BaseImageProcessor): + r""" + Constructs a EfficientFormer image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize: + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + crop_size: Dict[str, int] = None, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample: + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + + if "shortest_edge" in size: + size = get_resize_output_image_size( + image, size=size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + ) + # size = get_resize_output_image_size(image, size["shortest_edge"], size["longest_edge"]) + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError(f"Size must contain 'height' and 'width' keys or 'shortest_edge' key. Got {size.keys()}") + return resize( + image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has + an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + resample = resample if resample is not None else self.resample + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size_dict = get_size_dict(size) + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + if not is_batched(images): + images = [images] + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["EfficientFormerImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a45fe7da5da3f3077b48b3aa23ca557cc60dd5e9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_efficientformer.py @@ -0,0 +1,807 @@ +# coding=utf-8 +# Copyright 2022 Snapchat Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch EfficientFormer model.""" + +import itertools +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ....modeling_utils import PreTrainedModel +from ....utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_efficientformer import EfficientFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "EfficientFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 448] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +class EfficientFormerPatchEmbeddings(nn.Module): + """ + This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels, + height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride] + """ + + def __init__(self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True): + super().__init__() + self.num_channels = num_channels + + self.projection = nn.Conv2d( + num_channels, + embed_dim, + kernel_size=config.downsample_patch_size, + stride=config.downsample_stride, + padding=config.downsample_pad, + ) + self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity() + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + embeddings = self.projection(pixel_values) + embeddings = self.norm(embeddings) + + return embeddings + + +class EfficientFormerSelfAttention(nn.Module): + def __init__(self, dim: int, key_dim: int, num_heads: int, attention_ratio: int, resolution: int): + super().__init__() + + self.num_heads = num_heads + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.scale = key_dim**-0.5 + self.total_key_dim = key_dim * num_heads + self.expanded_key_dim = int(attention_ratio * key_dim) + self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads) + hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2 + self.qkv = nn.Linear(dim, hidden_size) + self.projection = nn.Linear(self.total_expanded_key_dim, dim) + points = list(itertools.product(range(resolution), range(resolution))) + num_points = len(points) + attention_offsets = {} + idxs = [] + for point_1 in points: + for point_2 in points: + offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(num_points, num_points)) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, "ab"): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]: + batch_size, sequence_length, num_channels = hidden_states.shape + qkv = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv.reshape(batch_size, sequence_length, self.num_heads, -1).split( + [self.key_dim, self.key_dim, self.expanded_key_dim], dim=3 + ) + query_layer = query_layer.permute(0, 2, 1, 3) + key_layer = key_layer.permute(0, 2, 1, 3) + value_layer = value_layer.permute(0, 2, 1, 3) + + # set `model.to(torch_device)` won't change `self.ab.device`, if there is no follow-up `train` or `eval` call. + # Let's do it manually here, so users won't have to do this everytime. + if not self.training: + self.ab = self.ab.to(self.attention_biases.device) + attention_probs = (torch.matmul(query_layer, key_layer.transpose(-2, -1))) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + ) + + attention_probs = attention_probs.softmax(dim=-1) + + context_layer = torch.matmul(attention_probs, value_layer).transpose(1, 2) + context_layer = context_layer.reshape(batch_size, sequence_length, self.total_expanded_key_dim) + context_layer = self.projection(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class EfficientFormerConvStem(nn.Module): + def __init__(self, config: EfficientFormerConfig, out_channels: int): + super().__init__() + + self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1) + self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps) + + self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1) + self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps) + + self.activation = nn.ReLU() + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + features = self.batchnorm_before(self.convolution1(pixel_values)) + features = self.activation(features) + features = self.batchnorm_after(self.convolution2(features)) + features = self.activation(features) + + return features + + +class EfficientFormerPooling(nn.Module): + def __init__(self, pool_size: int): + super().__init__() + self.pool = nn.AvgPool2d(pool_size, stride=1, padding=pool_size // 2, count_include_pad=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + output = self.pool(hidden_states) - hidden_states + return output + + +class EfficientFormerDenseMlp(nn.Module): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.linear_in = nn.Linear(in_features, hidden_features) + self.activation = ACT2FN[config.hidden_act] + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.linear_out = nn.Linear(hidden_features, out_features) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_in(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.linear_out(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class EfficientFormerConvMlp(nn.Module): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + drop: float = 0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.convolution1 = nn.Conv2d(in_features, hidden_features, 1) + self.activation = ACT2FN[config.hidden_act] + self.convolution2 = nn.Conv2d(hidden_features, out_features, 1) + self.dropout = nn.Dropout(drop) + + self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps) + self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.convolution1(hidden_state) + hidden_state = self.batchnorm_before(hidden_state) + + hidden_state = self.activation(hidden_state) + hidden_state = self.dropout(hidden_state) + hidden_state = self.convolution2(hidden_state) + + hidden_state = self.batchnorm_after(hidden_state) + hidden_state = self.dropout(hidden_state) + + return hidden_state + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class EfficientFormerDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class EfficientFormerFlat(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + hidden_states = hidden_states.flatten(2).transpose(1, 2) + return hidden_states + + +class EfficientFormerMeta3D(nn.Module): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0): + super().__init__() + + self.token_mixer = EfficientFormerSelfAttention( + dim=config.dim, + key_dim=config.key_dim, + num_heads=config.num_attention_heads, + attention_ratio=config.attention_ratio, + resolution=config.resolution, + ) + + self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim) + + self.drop_path = EfficientFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_layer_scale = config.use_layer_scale + if config.use_layer_scale: + self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]: + self_attention_outputs = self.token_mixer(self.layernorm1(hidden_states), output_attentions) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.use_layer_scale: + layer_output = hidden_states + self.drop_path( + self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attention_output + ) + layer_output = layer_output + self.drop_path( + self.layer_scale_2.unsqueeze(0).unsqueeze(0) * self.mlp(self.layernorm2(layer_output)) + ) + else: + layer_output = hidden_states + self.drop_path(attention_output) + layer_output = layer_output + self.drop_path(self.mlp(self.layernorm2(layer_output))) + + outputs = (layer_output,) + outputs + + return outputs + + +class EfficientFormerMeta3DLayers(nn.Module): + def __init__(self, config: EfficientFormerConfig): + super().__init__() + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:-1])) + for block_idx in range(config.num_meta3d_blocks) + ] + self.blocks = nn.ModuleList( + [EfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path) for drop_path in drop_paths] + ) + + def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]: + all_attention_outputs = () if output_attentions else None + + for layer_module in self.blocks: + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + hidden_states = layer_module(hidden_states, output_attentions) + + if output_attentions: + all_attention_outputs = all_attention_outputs + (hidden_states[1],) + + if output_attentions: + outputs = (hidden_states[0],) + all_attention_outputs + return outputs + + return hidden_states + + +class EfficientFormerMeta4D(nn.Module): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0): + super().__init__() + pool_size = config.pool_size if config.pool_size is not None else 3 + self.token_mixer = EfficientFormerPooling(pool_size=pool_size) + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = EfficientFormerConvMlp( + config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob + ) + + self.drop_path = EfficientFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_layer_scale = config.use_layer_scale + if config.use_layer_scale: + self.layer_scale_1 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.layer_scale_2 = nn.Parameter(config.layer_scale_init_value * torch.ones((dim)), requires_grad=True) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + outputs = self.token_mixer(hidden_states) + + if self.use_layer_scale: + layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs) + + layer_output = layer_output + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output) + ) + else: + layer_output = hidden_states + self.drop_path(outputs) + layer_output = layer_output + self.drop_path(self.mlp(layer_output)) + + return layer_output + + +class EfficientFormerMeta4DLayers(nn.Module): + def __init__(self, config: EfficientFormerConfig, stage_idx: int): + super().__init__() + num_layers = ( + config.depths[stage_idx] if stage_idx != -1 else config.depths[stage_idx] - config.num_meta3d_blocks + ) + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers) + ] + + self.blocks = nn.ModuleList( + [ + EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path) + for drop_path in drop_paths + ] + ) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + for layer_module in self.blocks: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class EfficientFormerIntermediateStage(nn.Module): + def __init__(self, config: EfficientFormerConfig, index: int): + super().__init__() + self.meta4D_layers = EfficientFormerMeta4DLayers(config, index) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: + hidden_states = self.meta4D_layers(hidden_states) + return hidden_states + + +class EfficientFormerLastStage(nn.Module): + def __init__(self, config: EfficientFormerConfig): + super().__init__() + self.meta4D_layers = EfficientFormerMeta4DLayers(config, -1) + self.flat = EfficientFormerFlat() + self.meta3D_layers = EfficientFormerMeta3DLayers(config) + + def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]: + hidden_states = self.meta4D_layers(hidden_states) + hidden_states = self.flat(hidden_states) + hidden_states = self.meta3D_layers(hidden_states, output_attentions) + + return hidden_states + + +class EfficientFormerEncoder(nn.Module): + def __init__(self, config: EfficientFormerConfig): + super().__init__() + self.config = config + num_intermediate_stages = len(config.depths) - 1 + downsamples = [ + config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1] + for i in range(num_intermediate_stages) + ] + intermediate_stages = [] + + for i in range(num_intermediate_stages): + intermediate_stages.append(EfficientFormerIntermediateStage(config, i)) + if downsamples[i]: + intermediate_stages.append( + EfficientFormerPatchEmbeddings(config, config.hidden_sizes[i], config.hidden_sizes[i + 1]) + ) + + self.intermediate_stages = nn.ModuleList(intermediate_stages) + self.last_stage = EfficientFormerLastStage(config) + + def forward( + self, + hidden_states: torch.Tensor, + output_hidden_states: bool = False, + output_attentions: bool = False, + return_dict: bool = True, + ) -> BaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for layer_module in self.intermediate_stages: + hidden_states = layer_module(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_output = self.last_stage(hidden_states, output_attentions=output_attentions) + + if output_attentions: + all_self_attentions = all_self_attentions + layer_output[1:] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (layer_output[0],) + + if not return_dict: + return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=layer_output[0], + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class EfficientFormerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientFormerConfig + base_model_prefix = "efficientformer" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + + def _init_weights(self, module: nn.Module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +EFFICIENTFORMER_START_DOCSTRING = r""" + This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) subclass. Use it as a + regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + + Parameters: + config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +EFFICIENTFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`ViTImageProcessor`]. See + [`ViTImageProcessor.preprocess`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.", + EFFICIENTFORMER_START_DOCSTRING, +) +class EfficientFormerModel(EfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig): + super().__init__(config) + self.config = config + _no_split_modules = ["EfficientFormerMeta4D"] + + self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0]) + self.encoder = EfficientFormerEncoder(config) + self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.patch_embed(pixel_values) + encoder_outputs = self.encoder( + embedding_output, output_attentions=output_attentions, output_hidden_states=output_hidden_states + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with an image classification head on top (a linear layer on top of the final + hidden state of the [CLS] token) e.g. for ImageNet. + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class EfficientFormerForImageClassification(EfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = EfficientFormerModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.efficientformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output.mean(-2)) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@dataclass +class EfficientFormerForImageClassificationWithTeacherOutput(ModelOutput): + """ + Output type of [`EfficientFormerForImageClassificationWithTeacher`]. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores as the average of the cls_logits and distillation logits. + cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: Optional[torch.FloatTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + distillation_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden + state of the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for + ImageNet. + + + + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. + + + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class EfficientFormerForImageClassificationWithTeacher(EfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = EfficientFormerModel(config) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + # Distillation head + self.distillation_classifier = ( + nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=EfficientFormerForImageClassificationWithTeacherOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, EfficientFormerForImageClassificationWithTeacherOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.efficientformer( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + cls_logits = self.classifier(sequence_output.mean(-2)) + distillation_logits = self.distillation_classifier(sequence_output.mean(-2)) + + # during inference, return the average of both classifier predictions + logits = (cls_logits + distillation_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distillation_logits) + outputs[1:] + return output + + return EfficientFormerForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distillation_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "EfficientFormerForImageClassification", + "EfficientFormerForImageClassificationWithTeacher", + "EfficientFormerModel", + "EfficientFormerPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e11fa1edf9d08f6ae773a025a9f3fbf7039cb1c9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/efficientformer/modeling_tf_efficientformer.py @@ -0,0 +1,1198 @@ +# coding=utf-8 +# Copyright 2023 Snapchat Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow EfficientFormer model.""" + +import itertools +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ....activations_tf import ACT2FN +from ....modeling_tf_outputs import ( + TFBaseModelOutput, + TFBaseModelOutputWithPooling, + TFImageClassifierOutput, +) +from ....modeling_tf_utils import ( + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ....tf_utils import shape_list, stable_softmax +from ....utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_efficientformer import EfficientFormerConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "EfficientFormerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 448] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300" +_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_281" + + +class TFEfficientFormerPatchEmbeddings(keras.layers.Layer): + """ + This class performs downsampling between two stages. For the input tensor with the shape [batch_size, num_channels, + height, width] it produces output tensor with the shape [batch_size, num_channels, height/stride, width/stride] + """ + + def __init__( + self, config: EfficientFormerConfig, num_channels: int, embed_dim: int, apply_norm: bool = True, **kwargs + ) -> None: + super().__init__(**kwargs) + self.num_channels = num_channels + + self.padding = keras.layers.ZeroPadding2D(padding=config.downsample_pad) + self.projection = keras.layers.Conv2D( + filters=embed_dim, + kernel_size=config.downsample_patch_size, + strides=config.downsample_stride, + padding="valid", + name="projection", + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.norm = ( + keras.layers.BatchNormalization(axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="norm") + if apply_norm + else tf.identity + ) + self.embed_dim = embed_dim + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + tf.debugging.assert_shapes( + [(pixel_values, (..., None, None, self.num_channels))], + message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.", + ) + embeddings = self.projection(self.padding(pixel_values)) + embeddings = self.norm(embeddings, training=training) + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + if getattr(self, "norm", None) is not None: + if hasattr(self.norm, "name"): + with tf.name_scope(self.norm.name): + self.norm.build([None, None, None, self.embed_dim]) + + +class TFEfficientFormerSelfAttention(keras.layers.Layer): + def __init__( + self, + dim: int, + key_dim: int, + num_heads: int, + attention_ratio: int, + resolution: int, + config: EfficientFormerConfig, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_heads = num_heads + self.key_dim = key_dim + self.attention_ratio = attention_ratio + self.scale = key_dim**-0.5 + self.total_key_dim = key_dim * num_heads + self.expanded_key_dim = int(attention_ratio * key_dim) + self.total_expanded_key_dim = int(self.expanded_key_dim * num_heads) + hidden_size = self.total_expanded_key_dim + self.total_key_dim * 2 + + self.qkv = keras.layers.Dense( + units=hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="qkv" + ) + self.projection = keras.layers.Dense( + units=dim, kernel_initializer=get_initializer(config.initializer_range), name="projection" + ) + self.resolution = resolution + self.dim = dim + + def build(self, input_shape: tf.TensorShape) -> None: + points = list(itertools.product(range(self.resolution), range(self.resolution))) + num_points = len(points) + attention_offsets = {} + + idxs = [] + + for point_1 in points: + for point_2 in points: + offset = (abs(point_1[0] - point_2[0]), abs(point_1[1] - point_2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + + self.attention_biases = self.add_weight( + shape=(self.num_heads, len(attention_offsets)), + initializer=keras.initializers.zeros(), + trainable=True, + name="attention_biases", + ) + self.attention_bias_idxs = self.add_weight( + shape=(num_points, num_points), + trainable=False, + dtype=tf.int32, + name="attention_bias_idxs", + ) + + self.attention_bias_idxs.assign(tf.reshape(tf.cast(idxs, dtype=tf.int32), (num_points, num_points))) + + if self.built: + return + self.built = True + if getattr(self, "qkv", None) is not None: + with tf.name_scope(self.qkv.name): + self.qkv.build([None, None, self.dim]) + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, self.total_expanded_key_dim]) + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + batch_size, sequence_length, *_ = shape_list(hidden_states) + qkv = self.qkv(inputs=hidden_states) + + query_layer, key_layer, value_layer = tf.split( + tf.reshape(tensor=qkv, shape=(batch_size, sequence_length, self.num_heads, -1)), + num_or_size_splits=[self.key_dim, self.key_dim, self.expanded_key_dim], + axis=3, + ) + + query_layer = tf.transpose(query_layer, perm=[0, 2, 1, 3]) + key_layer = tf.transpose(key_layer, perm=[0, 2, 1, 3]) + value_layer = tf.transpose(value_layer, perm=[0, 2, 1, 3]) + + attention_probs = tf.matmul(query_layer, tf.transpose(key_layer, perm=[0, 1, 3, 2])) + scale = tf.cast(self.scale, dtype=attention_probs.dtype) + attention_probs = tf.multiply(attention_probs, scale) + + attention_biases = tf.gather(params=self.attention_biases, indices=self.attention_bias_idxs, axis=1) + attention_probs = attention_probs + attention_biases + attention_probs = stable_softmax(logits=attention_probs, axis=-1) + + context_layer = tf.matmul(attention_probs, value_layer) + context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3]) + + context_layer = tf.reshape( + tensor=context_layer, shape=(batch_size, sequence_length, self.total_expanded_key_dim) + ) + context_layer = self.projection(context_layer) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TFEfficientFormerConvStem(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, out_channels: int, **kwargs): + super().__init__(**kwargs) + + self.padding = keras.layers.ZeroPadding2D(padding=1) + self.convolution1 = keras.layers.Conv2D( + filters=out_channels // 2, kernel_size=3, strides=2, padding="valid", name="convolution1" + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_before = keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" + ) + + self.convolution2 = keras.layers.Conv2D( + filters=out_channels, + kernel_size=3, + strides=2, + padding="valid", + name="convolution2", + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_after = keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" + ) + + self.activation = keras.layers.Activation(activation=keras.activations.relu, name="activation") + self.out_channels = out_channels + self.config = config + + def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor: + features = self.batchnorm_before(self.convolution1(self.padding(pixel_values)), training=training) + features = self.activation(features) + features = self.batchnorm_after(self.convolution2(self.padding(features)), training=training) + features = self.activation(features) + return features + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution1", None) is not None: + with tf.name_scope(self.convolution1.name): + self.convolution1.build([None, None, None, self.config.num_channels]) + if getattr(self, "batchnorm_before", None) is not None: + with tf.name_scope(self.batchnorm_before.name): + self.batchnorm_before.build([None, None, None, self.out_channels // 2]) + if getattr(self, "convolution2", None) is not None: + with tf.name_scope(self.convolution2.name): + self.convolution2.build([None, None, None, self.out_channels // 2]) + if getattr(self, "batchnorm_after", None) is not None: + with tf.name_scope(self.batchnorm_after.name): + self.batchnorm_after.build([None, None, None, self.out_channels]) + if getattr(self, "activation", None) is not None: + with tf.name_scope(self.activation.name): + self.activation.build(None) + + +class TFEfficientFormerPooling(keras.layers.Layer): + def __init__(self, pool_size: int, **kwargs): + super().__init__(**kwargs) + self.pool = keras.layers.AveragePooling2D(pool_size=pool_size, strides=1, padding="same") + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + output = self.pool(hidden_states) + output = output - hidden_states + return output + + +class TFEfficientFormerDenseMlp(keras.layers.Layer): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.linear_in = keras.layers.Dense( + units=hidden_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_in" + ) + self.activation = ACT2FN[config.hidden_act] + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + self.linear_out = keras.layers.Dense( + units=out_features, kernel_initializer=get_initializer(config.initializer_range), name="linear_out" + ) + self.hidden_features = hidden_features + self.in_features = in_features + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.linear_in(inputs=hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.linear_out(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "linear_in", None) is not None: + with tf.name_scope(self.linear_in.name): + self.linear_in.build([None, None, self.in_features]) + if getattr(self, "linear_out", None) is not None: + with tf.name_scope(self.linear_out.name): + self.linear_out.build([None, None, self.hidden_features]) + + +class TFEfficientFormerConvMlp(keras.layers.Layer): + def __init__( + self, + config: EfficientFormerConfig, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + drop: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.convolution1 = keras.layers.Conv2D( + filters=hidden_features, + kernel_size=1, + name="convolution1", + padding="valid", + ) + + self.activation = ACT2FN[config.hidden_act] + + self.convolution2 = keras.layers.Conv2D( + filters=out_features, + kernel_size=1, + name="convolution2", + padding="valid", + ) + + self.dropout = keras.layers.Dropout(rate=drop) + + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_before = keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_before" + ) + # Use same default momentum and epsilon as PyTorch equivalent for BatchNormalization + self.batchnorm_after = keras.layers.BatchNormalization( + axis=-1, epsilon=config.batch_norm_eps, momentum=0.9, name="batchnorm_after" + ) + self.hidden_features = hidden_features + self.in_features = in_features + self.out_features = out_features + + def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_state = self.convolution1(hidden_state) + hidden_state = self.batchnorm_before(hidden_state, training=training) + hidden_state = self.activation(hidden_state) + hidden_state = self.dropout(hidden_state, training=training) + hidden_state = self.convolution2(hidden_state) + hidden_state = self.batchnorm_after(hidden_state, training=training) + hidden_state = self.dropout(hidden_state, training=training) + return hidden_state + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "convolution1", None) is not None: + with tf.name_scope(self.convolution1.name): + self.convolution1.build([None, None, None, self.in_features]) + if getattr(self, "convolution2", None) is not None: + with tf.name_scope(self.convolution2.name): + self.convolution2.build([None, None, None, self.hidden_features]) + if getattr(self, "batchnorm_before", None) is not None: + with tf.name_scope(self.batchnorm_before.name): + self.batchnorm_before.build([None, None, None, self.hidden_features]) + if getattr(self, "batchnorm_after", None) is not None: + with tf.name_scope(self.batchnorm_after.name): + self.batchnorm_after.build([None, None, None, self.out_features]) + + +# Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->EfficientFormer +class TFEfficientFormerDropPath(keras.layers.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + References: + (1) github.com:rwightman/pytorch-image-models + """ + + def __init__(self, drop_path: float, **kwargs): + super().__init__(**kwargs) + self.drop_path = drop_path + + def call(self, x: tf.Tensor, training=None): + if training: + keep_prob = 1 - self.drop_path + shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1) + random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) + random_tensor = tf.floor(random_tensor) + return (x / keep_prob) * random_tensor + return x + + +class TFEfficientFormerFlat(keras.layers.Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, hidden_states: tf.Tensor) -> Tuple[tf.Tensor]: + batch_size, _, _, in_channels = shape_list(hidden_states) + hidden_states = tf.reshape(hidden_states, shape=[batch_size, -1, in_channels]) + return hidden_states + + +class TFEfficientFormerMeta3D(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + + self.token_mixer = TFEfficientFormerSelfAttention( + dim=config.dim, + key_dim=config.key_dim, + num_heads=config.num_attention_heads, + attention_ratio=config.attention_ratio, + resolution=config.resolution, + name="token_mixer", + config=config, + ) + self.dim = dim + self.config = config + + self.layernorm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm1") + self.layernorm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm2") + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = TFEfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim, name="mlp") + + # Using `layers.Activation` instead of `tf.identity` to better control `training' behavior. + self.drop_path = ( + TFEfficientFormerDropPath(drop_path) + if drop_path > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + self.config = config + + def build(self, input_shape=None): + self.layer_scale_1 = None + self.layer_scale_2 = None + + if self.config.use_layer_scale: + self.layer_scale_1 = self.add_weight( + shape=(self.dim,), + initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_1", + ) + self.layer_scale_2 = self.add_weight( + shape=(self.dim,), + initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_2", + ) + + if self.built: + return + self.built = True + if getattr(self, "token_mixer", None) is not None: + with tf.name_scope(self.token_mixer.name): + self.token_mixer.build(None) + if getattr(self, "layernorm1", None) is not None: + with tf.name_scope(self.layernorm1.name): + self.layernorm1.build([None, None, self.dim]) + if getattr(self, "layernorm2", None) is not None: + with tf.name_scope(self.layernorm2.name): + self.layernorm2.build([None, None, self.dim]) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + self_attention_outputs = self.token_mixer( + hidden_states=self.layernorm1(hidden_states, training=training), + output_attentions=output_attentions, + training=training, + ) + + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + if self.config.use_layer_scale: + layer_output = hidden_states + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * attention_output, + training=training, + ) + layer_output = layer_output + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) + * self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), + training=training, + ) + else: + layer_output = hidden_states + self.drop_path(attention_output, training=training) + layer_output = layer_output + self.drop_path( + self.mlp(hidden_states=self.layernorm2(inputs=layer_output, training=training), training=training), + training=training, + ) + + outputs = (layer_output,) + outputs + + return outputs + + +class TFEfficientFormerMeta3DLayers(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:-1])) + for block_idx in range(config.num_meta3d_blocks) + ] + self.blocks = [ + TFEfficientFormerMeta3D(config, config.hidden_sizes[-1], drop_path=drop_path, name=f"blocks.{i}") + for i, drop_path in enumerate(drop_paths) + ] + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + all_attention_outputs = () if output_attentions else None + + for i, layer_module in enumerate(self.blocks): + if isinstance(hidden_states, tuple): + hidden_states = hidden_states[0] + + hidden_states = layer_module( + hidden_states=hidden_states, output_attentions=output_attentions, training=training + ) + if output_attentions: + all_attention_outputs = all_attention_outputs + (hidden_states[1],) + + if output_attentions: + outputs = (hidden_states[0],) + all_attention_outputs + return outputs + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "blocks", None) is not None: + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFEfficientFormerMeta4D(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, dim: int, drop_path: float = 0.0, **kwargs): + super().__init__(**kwargs) + pool_size = config.pool_size if config.pool_size is not None else 3 + self.token_mixer = TFEfficientFormerPooling(pool_size=pool_size, name="token_mixer") + self.dim = dim + mlp_hidden_dim = int(dim * config.mlp_expansion_ratio) + self.mlp = TFEfficientFormerConvMlp( + config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=config.hidden_dropout_prob, name="mlp" + ) + + self.drop_path = ( + TFEfficientFormerDropPath(drop_path, name="drop_path") + if drop_path > 0.0 + else keras.layers.Activation("linear", name="drop_path") + ) + self.config = config + + def build(self, input_shape=None): + self.layer_scale_1 = None + self.layer_scale_2 = None + + if self.config.use_layer_scale: + self.layer_scale_1 = self.add_weight( + shape=(self.dim), + initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_1", + ) + self.layer_scale_2 = self.add_weight( + shape=(self.dim), + initializer=keras.initializers.Constant(value=self.config.layer_scale_init_value), + trainable=True, + name="layer_scale_2", + ) + + if self.built: + return + self.built = True + if getattr(self, "token_mixer", None) is not None: + with tf.name_scope(self.token_mixer.name): + self.token_mixer.build(None) + if getattr(self, "mlp", None) is not None: + with tf.name_scope(self.mlp.name): + self.mlp.build(None) + if getattr(self, "drop_path", None) is not None: + with tf.name_scope(self.drop_path.name): + self.drop_path.build(None) + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + outputs = self.token_mixer(hidden_states) + + if self.config.use_layer_scale: + layer_output = hidden_states + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_1, 0), 0) * outputs, + training=training, + ) + + layer_output = layer_output + self.drop_path( + tf.expand_dims(tf.expand_dims(self.layer_scale_2, 0), 0) + * self.mlp(hidden_state=layer_output, training=training), + training=training, + ) + + else: + layer_output = hidden_states + self.drop_path(outputs, training=training) + layer_output = layer_output + self.drop_path( + self.mlp(hidden_state=layer_output, training=training), training=training + ) + + return layer_output + + +class TFEfficientFormerMeta4DLayers(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, stage_idx: int, **kwargs): + super().__init__(**kwargs) + num_layers = ( + config.depths[stage_idx] if stage_idx != -1 else config.depths[stage_idx] - config.num_meta3d_blocks + ) + drop_paths = [ + config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers) + ] + + self.blocks = [ + TFEfficientFormerMeta4D( + config=config, dim=config.hidden_sizes[stage_idx], drop_path=drop_paths[i], name=f"blocks.{i}" + ) + for i in range(len(drop_paths)) + ] + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + for layer_module in self.blocks: + hidden_states = layer_module(hidden_states=hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "blocks", None) is not None: + for layer in self.blocks: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFEfficientFormerIntermediateStage(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, index: int, **kwargs): + super().__init__(**kwargs) + self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=index, name="meta4D_layers") + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> Tuple[tf.Tensor]: + hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "meta4D_layers", None) is not None: + with tf.name_scope(self.meta4D_layers.name): + self.meta4D_layers.build(None) + + +class TFEfficientFormerLastStage(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + self.meta4D_layers = TFEfficientFormerMeta4DLayers(config=config, stage_idx=-1, name="meta4D_layers") + self.flat = TFEfficientFormerFlat(name="flat") + self.meta3D_layers = TFEfficientFormerMeta3DLayers(config, name="meta3D_layers") + + def call( + self, hidden_states: tf.Tensor, output_attentions: bool = False, training: bool = False + ) -> Tuple[tf.Tensor]: + hidden_states = self.meta4D_layers(hidden_states=hidden_states, training=training) + hidden_states = self.flat(hidden_states=hidden_states) + hidden_states = self.meta3D_layers( + hidden_states=hidden_states, output_attentions=output_attentions, training=training + ) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "meta4D_layers", None) is not None: + with tf.name_scope(self.meta4D_layers.name): + self.meta4D_layers.build(None) + if getattr(self, "flat", None) is not None: + with tf.name_scope(self.flat.name): + self.flat.build(None) + if getattr(self, "meta3D_layers", None) is not None: + with tf.name_scope(self.meta3D_layers.name): + self.meta3D_layers.build(None) + + +class TFEfficientFormerEncoder(keras.layers.Layer): + def __init__(self, config: EfficientFormerConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + num_intermediate_stages = len(config.depths) - 1 + downsamples = [ + config.downsamples[i] or config.hidden_sizes[i] != config.hidden_sizes[i + 1] + for i in range(num_intermediate_stages) + ] + + intermediate_stages = [] + layer_count = -1 + for i in range(num_intermediate_stages): + layer_count += 1 + intermediate_stages.append( + TFEfficientFormerIntermediateStage(config, i, name=f"intermediate_stages.{layer_count}") + ) + if downsamples[i]: + layer_count += 1 + intermediate_stages.append( + TFEfficientFormerPatchEmbeddings( + config, + config.hidden_sizes[i], + config.hidden_sizes[i + 1], + name=f"intermediate_stages.{layer_count}", + ) + ) + self.intermediate_stages = intermediate_stages + self.last_stage = TFEfficientFormerLastStage(config, name="last_stage") + + def call( + self, + hidden_states: tf.Tensor, + output_hidden_states: bool, + output_attentions: bool, + return_dict: bool, + training: bool = False, + ) -> TFBaseModelOutput: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + for layer_module in self.intermediate_stages: + hidden_states = layer_module(hidden_states, training=training) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_output = self.last_stage(hidden_states, output_attentions=output_attentions, training=training) + + if output_attentions: + all_self_attentions = all_self_attentions + layer_output[1:] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (layer_output[0],) + + if not return_dict: + return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=layer_output[0], + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "last_stage", None) is not None: + with tf.name_scope(self.last_stage.name): + self.last_stage.build(None) + for layer in self.intermediate_stages: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFEfficientFormerMainLayer(keras.layers.Layer): + config_class = EfficientFormerConfig + + def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: + super().__init__(**kwargs) + self.config = config + + self.patch_embed = TFEfficientFormerConvStem(config, config.hidden_sizes[0], name="patch_embed") + self.encoder = TFEfficientFormerEncoder(config, name="encoder") + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + @unpack_inputs + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[tf.Tensor] = None, + output_hidden_states: Optional[tf.Tensor] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # When running on CPU, keras.layers.Conv2D and keras.layers.AveragePool2D do not + # support channels first NCHW format. A number of blocks contain both. + # So change the input format from (batch_size, num_channels, height, width) to + # (batch_size, height, width, num_channels) here. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + embedding_output = self.patch_embed(pixel_values, training=training) + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output, training=training) + + # Change the hidden states from (batch_size, height, width, num_channels) to + # (batch_size, num_channels, height, width). + # The hidden states are in (batch_size, height, width, num_channels) + # shape after all stages except the MB3D blocks. + if output_hidden_states: + hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1][:-1]]) + ( + encoder_outputs[1][-1], + ) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return TFBaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "patch_embed", None) is not None: + with tf.name_scope(self.patch_embed.name): + self.patch_embed.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.config.hidden_sizes[-1]]) + + +class TFEfficientFormerPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientFormerConfig + base_model_prefix = "efficientformer" + main_input_name = "pixel_values" + + +EFFICIENTFORMER_START_DOCSTRING = r""" + This model is a TensorFlow + [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer). Use it as a regular + TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior. + + + Parameters: + config ([`EfficientFormerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +EFFICIENTFORMER_INPUTS_DOCSTRING = r""" + Args: + pixel_values ((`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`EfficientFormerImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare EfficientFormer Model transformer outputting raw hidden-states without any specific head on top.", + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerModel(TFEfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[Tuple, TFBaseModelOutput]: + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "efficientformer", None) is not None: + with tf.name_scope(self.efficientformer.name): + self.efficientformer.build(None) + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with an image classification head on top of pooled last hidden state, e.g. for + ImageNet. + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerForImageClassification(TFEfficientFormerPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: EfficientFormerConfig): + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + # Classifier head + self.classifier = ( + keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier") + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + labels: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tf.Tensor, TFImageClassifierOutput]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFImageClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "efficientformer", None) is not None: + with tf.name_scope(self.efficientformer.name): + self.efficientformer.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_sizes[-1]]) + + +@dataclass +class TFEfficientFormerForImageClassificationWithTeacherOutput(ModelOutput): + """ + Args: + Output type of [`EfficientFormerForImageClassificationWithTeacher`]. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores as the average of the cls_logits and distillation logits. + cls_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the + class token). + distillation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the + distillation token). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: Optional[tf.Tensor] = None + cls_logits: Optional[tf.Tensor] = None + distillation_logits: Optional[tf.Tensor] = None + hidden_states: Optional[Tuple[tf.Tensor]] = None + attentions: Optional[Tuple[tf.Tensor]] = None + + +@add_start_docstrings( + """ + EfficientFormer Model transformer with image classification heads on top (a linear layer on top of the final hidden + state and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet. + + .. warning:: + This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet + supported. + """, + EFFICIENTFORMER_START_DOCSTRING, +) +class TFEfficientFormerForImageClassificationWithTeacher(TFEfficientFormerPreTrainedModel): + def __init__(self, config: EfficientFormerConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.efficientformer = TFEfficientFormerMainLayer(config, name="efficientformer") + + # Classifier heads + self.classifier = ( + keras.layers.Dense(config.num_labels, name="classifier") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="classifier") + ) + self.distillation_classifier = ( + keras.layers.Dense(config.num_labels, name="distillation_classifier") + if config.num_labels > 0 + else keras.layers.Activation("linear", name="distillation_classifier") + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(EFFICIENTFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFEfficientFormerForImageClassificationWithTeacherOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: Optional[tf.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[tuple, TFEfficientFormerForImageClassificationWithTeacherOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if training: + raise Exception( + "This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet supported." + ) + + outputs = self.efficientformer( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + cls_logits = self.classifier(tf.reduce_mean(sequence_output, axis=-2)) + distillation_logits = self.distillation_classifier(tf.reduce_mean(sequence_output, axis=-2)) + logits = (cls_logits + distillation_logits) / 2 + + if not return_dict: + output = (logits, cls_logits, distillation_logits) + outputs[1:] + return output + + return TFEfficientFormerForImageClassificationWithTeacherOutput( + logits=logits, + cls_logits=cls_logits, + distillation_logits=distillation_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "efficientformer", None) is not None: + with tf.name_scope(self.efficientformer.name): + self.efficientformer.build(None) + if getattr(self, "classifier", None) is not None: + if hasattr(self.classifier, "name"): + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_sizes[-1]]) + if getattr(self, "distillation_classifier", None) is not None: + if hasattr(self.distillation_classifier, "name"): + with tf.name_scope(self.distillation_classifier.name): + self.distillation_classifier.build([None, None, self.config.hidden_sizes[-1]]) + + +__all__ = [ + "TFEfficientFormerForImageClassification", + "TFEfficientFormerForImageClassificationWithTeacher", + "TFEfficientFormerModel", + "TFEfficientFormerPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2beb8f463ff10af33ea980599ea4d5fc05888acb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2023 The HuggingFace and Baidu Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie_m import * + from .modeling_ernie_m import * + from .tokenization_ernie_m import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py new file mode 100644 index 0000000000000000000000000000000000000000..7a45106131850a82cb169fb01cecf96ff5f3e1a2 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/configuration_ernie_m.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ErnieM model configuration""" +# Adapted from original paddlenlp repository.(https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/ernie_m/configuration.py) + +from __future__ import annotations + +from typing import Dict + +from ....configuration_utils import PretrainedConfig + + +class ErnieMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ErnieMModel`]. It is used to instantiate a + Ernie-M model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the `Ernie-M` + [susnato/ernie-m-base_pytorch](https://huggingface.co/susnato/ernie-m-base_pytorch) architecture. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 250002): + Vocabulary size of `inputs_ids` in [`ErnieMModel`]. Also is the vocab size of token embedding matrix. + Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling + [`ErnieMModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the embedding layer, encoder layers and pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors to feed-forward layers are + firstly projected from hidden_size to intermediate_size, and then projected back to hidden_size. Typically + intermediate_size is larger than hidden_size. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function in the feed-forward layer. `"gelu"`, `"relu"` and any other torch + supported activation functions are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability used in `MultiHeadAttention` in all encoder layers to drop some attention target. + max_position_embeddings (`int`, *optional*, defaults to 514): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length + of an input sequence. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the normal initializer for initializing all weight matrices. The index of padding + token in the token vocabulary. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + act_dropout (`float`, *optional*, defaults to 0.0): + This dropout probability is used in `ErnieMEncoderLayer` after activation. + + A normal_initializer initializes weight matrices as normal distributions. See + `ErnieMPretrainedModel._init_weights()` for how weights are initialized in `ErnieMModel`. + """ + + model_type = "ernie_m" + attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"} + + def __init__( + self, + vocab_size: int = 250002, + hidden_size: int = 768, + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + intermediate_size: int = 3072, + hidden_act: str = "gelu", + hidden_dropout_prob: float = 0.1, + attention_probs_dropout_prob: float = 0.1, + max_position_embeddings: int = 514, + initializer_range: float = 0.02, + pad_token_id: int = 1, + layer_norm_eps: float = 1e-05, + classifier_dropout=None, + act_dropout=0.0, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.classifier_dropout = classifier_dropout + self.act_dropout = act_dropout + + +__all__ = ["ErnieMConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py new file mode 100644 index 0000000000000000000000000000000000000000..28c17afa3f7a2ca10d340794ac0b0a6b28e27915 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -0,0 +1,1058 @@ +# coding=utf-8 +# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ErnieM model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn, tensor +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_ernie_m import ErnieMConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "susnato/ernie-m-base_pytorch" +_CONFIG_FOR_DOC = "ErnieMConfig" +_TOKENIZER_FOR_DOC = "ErnieMTokenizer" + + +# Adapted from paddlenlp.transformers.ernie_m.modeling.ErnieEmbeddings +class ErnieMEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=config.pad_token_id + ) + self.layer_norm = nn.LayerNorm(normalized_shape=config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + self.padding_idx = config.pad_token_id + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if position_ids is None: + input_shape = inputs_embeds.size()[:-1] + ones = torch.ones(input_shape, dtype=torch.int64, device=inputs_embeds.device) + seq_length = torch.cumsum(ones, dim=1) + position_ids = seq_length - ones + + if past_key_values_length > 0: + position_ids = position_ids + past_key_values_length + # to mimic paddlenlp implementation + position_ids += 2 + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + position_embeddings + embeddings = self.layer_norm(embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + + +class ErnieMSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.q_proj = nn.Linear(config.hidden_size, self.all_head_size) + self.k_proj = nn.Linear(config.hidden_size, self.all_head_size) + self.v_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.q_proj(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ErnieMModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class ErnieMAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self_attn = ErnieMSelfAttention(config, position_embedding_type=position_embedding_type) + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self_attn.num_attention_heads, self.self_attn.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self_attn.q_proj = prune_linear_layer(self.self_attn.q_proj, index) + self.self_attn.k_proj = prune_linear_layer(self.self_attn.k_proj, index) + self.self_attn.v_proj = prune_linear_layer(self.self_attn.v_proj, index) + self.out_proj = prune_linear_layer(self.out_proj, index, dim=1) + + # Update hyper params and store pruned heads + self.self_attn.num_attention_heads = self.self_attn.num_attention_heads - len(heads) + self.self_attn.all_head_size = self.self_attn.attention_head_size * self.self_attn.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self_attn( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.out_proj(self_outputs[0]) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ErnieMEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + # to mimic paddlenlp implementation + dropout = 0.1 if config.hidden_dropout_prob is None else config.hidden_dropout_prob + act_dropout = config.hidden_dropout_prob if config.act_dropout is None else config.act_dropout + + self.self_attn = ErnieMAttention(config) + self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = nn.Dropout(act_dropout) + self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = True, + ): + residual = hidden_states + if output_attentions: + hidden_states, attention_opt_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + + else: + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + hidden_states = residual + self.dropout1(hidden_states) + hidden_states = self.norm1(hidden_states) + residual = hidden_states + + hidden_states = self.linear1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.linear2(hidden_states) + hidden_states = residual + self.dropout2(hidden_states) + hidden_states = self.norm2(hidden_states) + + if output_attentions: + return hidden_states, attention_opt_weights + else: + return hidden_states + + +class ErnieMEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ErnieMEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + input_embeds: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + hidden_states = () if output_hidden_states else None + attentions = () if output_attentions else None + + output = input_embeds + if output_hidden_states: + hidden_states = hidden_states + (output,) + for i, layer in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + output, opt_attn_weights = layer( + hidden_states=output, + attention_mask=attention_mask, + head_mask=layer_head_mask, + past_key_value=past_key_value, + ) + + if output_hidden_states: + hidden_states = hidden_states + (output,) + if output_attentions: + attentions = attentions + (opt_attn_weights,) + + last_hidden_state = output + if not return_dict: + return tuple(v for v in [last_hidden_state, hidden_states, attentions] if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=attentions + ) + + +class ErnieMPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ErnieMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ErnieMConfig + base_model_prefix = "ernie_m" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +ERNIE_M_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ErnieMConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ERNIE_M_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`ErnieMTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ErnieM Model transformer outputting raw hidden-states without any specific head on top.", + ERNIE_M_START_DOCSTRING, +) +class ErnieMModel(ErnieMPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super(ErnieMModel, self).__init__(config) + self.initializer_range = config.initializer_range + self.embeddings = ErnieMEmbeddings(config) + self.encoder = ErnieMEncoder(config) + self.pooler = ErnieMPooler(config) if add_pooling_layer else None + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layers[layer].self_attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[tensor] = None, + position_ids: Optional[tensor] = None, + attention_mask: Optional[tensor] = None, + head_mask: Optional[tensor] = None, + inputs_embeds: Optional[tensor] = None, + past_key_values: Optional[Tuple[Tuple[tensor]]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPoolingAndCrossAttentions]: + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time.") + + # init the default bool value + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + + # Adapted from paddlenlp.transformers.ernie_m.ErnieMModel + if attention_mask is None: + attention_mask = (input_ids == self.config.pad_token_id).to(torch.float32) + attention_mask *= torch.finfo(attention_mask.dtype).min + if past_key_values is not None: + batch_size = past_key_values[0][0].shape[0] + past_mask = torch.zeros([batch_size, 1, 1, past_key_values_length], dtype=attention_mask.dtype) + attention_mask = torch.concat([past_mask, attention_mask], dim=-1) + # For 2D attention_mask from tokenizer + elif attention_mask.ndim == 2: + attention_mask = attention_mask.to(torch.float32) + attention_mask = 1.0 - attention_mask + attention_mask *= torch.finfo(attention_mask.dtype).min + + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + sequence_output = encoder_outputs[0] + pooler_output = self.pooler(sequence_output) if self.pooler is not None else None + return (sequence_output, pooler_output) + encoder_outputs[1:] + + sequence_output = encoder_outputs["last_hidden_state"] + pooler_output = self.pooler(sequence_output) if self.pooler is not None else None + hidden_states = None if not output_hidden_states else encoder_outputs["hidden_states"] + attentions = None if not output_attentions else encoder_outputs["attentions"] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooler_output, + hidden_states=hidden_states, + attentions=attentions, + ) + + +@add_start_docstrings( + """ErnieM Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks.""", + ERNIE_M_START_DOCSTRING, +) +class ErnieMForSequenceClassification(ErnieMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.ernie_m = ErnieMModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = True, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie_m( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ErnieM Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks.""", + ERNIE_M_START_DOCSTRING, +) +class ErnieMForMultipleChoice(ErnieMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.ernie_m = ErnieMModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.FloatTensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.ernie_m( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ErnieM Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.""", + ERNIE_M_START_DOCSTRING, +) +class ErnieMForTokenClassification(ErnieMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie_m = ErnieMModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = True, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.FloatTensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie_m( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ErnieM Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`).""", + ERNIE_M_START_DOCSTRING, +) +class ErnieMForQuestionAnswering(ErnieMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie_m = ErnieMModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie_m( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ErnieMForInformationExtraction is a Ernie-M Model with two linear layer on top of the hidden-states output to + compute `start_prob` and `end_prob`, designed for Universal Information Extraction.""", + ERNIE_M_START_DOCSTRING, +) +class ErnieMForInformationExtraction(ErnieMPreTrainedModel): + def __init__(self, config): + super(ErnieMForInformationExtraction, self).__init__(config) + self.ernie_m = ErnieMModel(config) + self.linear_start = nn.Linear(config.hidden_size, 1) + self.linear_end = nn.Linear(config.hidden_size, 1) + self.sigmoid = nn.Sigmoid() + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_M_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.FloatTensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for position (index) for computing the start_positions loss. Position outside of the sequence are + not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) for computing the end_positions loss. Position outside of the sequence are not + taken into account for computing the loss. + """ + + result = self.ernie_m( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + if return_dict: + sequence_output = result.last_hidden_state + elif not return_dict: + sequence_output = result[0] + + start_logits = self.linear_start(sequence_output) + start_logits = start_logits.squeeze(-1) + end_logits = self.linear_end(sequence_output) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = BCEWithLogitsLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + return tuple( + i + for i in [total_loss, start_logits, end_logits, result.hidden_states, result.attentions] + if i is not None + ) + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=result.hidden_states, + attentions=result.attentions, + ) + + +__all__ = [ + "ErnieMForMultipleChoice", + "ErnieMForQuestionAnswering", + "ErnieMForSequenceClassification", + "ErnieMForTokenClassification", + "ErnieMModel", + "ErnieMPreTrainedModel", + "ErnieMForInformationExtraction", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py new file mode 100644 index 0000000000000000000000000000000000000000..44bc197a4f7c463c71eb64dd941dacce51148675 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/ernie_m/tokenization_ernie_m.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright 2023 Xuan Ouyang, Shuohuan Wang, Chao Pang, Yu Sun, Hao Tian, Hua Wu, Haifeng Wang and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for Ernie-M.""" + +import io +import os +import unicodedata +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ....tokenization_utils import PreTrainedTokenizer +from ....utils import logging +from ....utils.import_utils import requires + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "sentencepiece_model_ckpt": "sentencepiece.bpe.model"} + +RESOURCE_FILES_NAMES = { + "sentencepiece_model_file": "sentencepiece.bpe.model", + "vocab_file": "vocab.txt", +} + + +# Adapted from paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer +@requires(backends=("sentencepiece",)) +class ErnieMTokenizer(PreTrainedTokenizer): + r""" + Constructs a Ernie-M tokenizer. It uses the `sentencepiece` tools to cut the words to sub-words. + + Args: + sentencepiece_model_file (`str`): + The file path of sentencepiece model. + vocab_file (`str`, *optional*): + The file path of the vocabulary. + do_lower_case (`str`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + A special token representing the `unknown (out-of-vocabulary)` token. An unknown token is set to be + `unk_token` inorder to be converted to an ID. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + A special token separating two different sentences in the same input. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + A special token used to make arrays of tokens the same size for batching purposes. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + A special token used for sequence classification. It is the last token of the sequence when built with + special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + A special token representing a masked token. This is the token used in the masked language modeling task + which the model tries to predict the original unmasked ones. + """ + + # Ernie-M model doesn't have token_type embedding. + model_input_names: List[str] = ["input_ids"] + + vocab_files_names = VOCAB_FILES_NAMES + resource_files_names = RESOURCE_FILES_NAMES + + def __init__( + self, + sentencepiece_model_ckpt, + vocab_file=None, + do_lower_case=False, + encoding="utf8", + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. include the space before it and + # is included in the raw text, there should be a match in a non-normalized sentence. + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.do_lower_case = do_lower_case + self.sentencepiece_model_ckpt = sentencepiece_model_ckpt + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(sentencepiece_model_ckpt) + + # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning + if vocab_file is not None: + self.vocab = self.load_vocab(filepath=vocab_file) + else: + self.vocab = {self.sp_model.id_to_piece(id): id for id in range(self.sp_model.get_piece_size())} + self.reverse_vocab = {v: k for k, v in self.vocab.items()} + + super().__init__( + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + vocab_file=vocab_file, + encoding=encoding, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + def get_offset_mapping(self, text): + if text is None: + return None + + split_tokens = self.tokenize(text) + normalized_text, char_mapping = "", [] + + for i, ch in enumerate(text): + if ch in self.SP_CHAR_MAPPING: + ch = self.SP_CHAR_MAPPING.get(ch) + else: + ch = unicodedata.normalize("NFKC", ch) + if self.is_whitespace(ch): + continue + normalized_text += ch + char_mapping.extend([i] * len(ch)) + + text, token_mapping, offset = normalized_text, [], 0 + + if self.do_lower_case: + text = text.lower() + + for token in split_tokens: + if token[:1] == "▁": + token = token[1:] + start = text[offset:].index(token) + offset + end = start + len(token) + + token_mapping.append((char_mapping[start], char_mapping[end - 1] + 1)) + offset = end + return token_mapping + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.sentencepiece_model_ckpt) + + def clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + return "".join((self.SP_CHAR_MAPPING.get(c, c) for c in text)) + + def _tokenize(self, text, enable_sampling=False, nbest_size=64, alpha=0.1): + """Tokenize a string.""" + + if self.sp_model_kwargs.get("enable_sampling") is True: + enable_sampling = True + if self.sp_model_kwargs.get("alpha") is not None: + alpha = self.sp_model_kwargs.get("alpha") + if self.sp_model_kwargs.get("nbest_size") is not None: + nbest_size = self.sp_model_kwargs.get("nbest_size") + + if not enable_sampling: + pieces = self.sp_model.EncodeAsPieces(text) + else: + pieces = self.sp_model.SampleEncodeAsPieces(text, nbest_size, alpha) + new_pieces = [] + for pi, piece in enumerate(pieces): + if piece == SPIECE_UNDERLINE: + if not pieces[pi + 1].startswith(SPIECE_UNDERLINE) and pi != 0: + new_pieces.append(SPIECE_UNDERLINE) + continue + else: + continue + lst_i = 0 + for i, chunk in enumerate(piece): + if chunk == SPIECE_UNDERLINE: + continue + if self.is_ch_char(chunk) or self.is_punct(chunk): + if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: + new_pieces.append(piece[lst_i:i]) + new_pieces.append(chunk) + lst_i = i + 1 + elif chunk.isdigit() and i > 0 and not piece[i - 1].isdigit(): + if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: + new_pieces.append(piece[lst_i:i]) + lst_i = i + elif not chunk.isdigit() and i > 0 and piece[i - 1].isdigit(): + if i > lst_i and piece[lst_i:i] != SPIECE_UNDERLINE: + new_pieces.append(piece[lst_i:i]) + lst_i = i + if len(piece) > lst_i: + new_pieces.append(piece[lst_i:]) + return new_pieces + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def convert_ids_to_string(self, ids): + """ + Converts a sequence of tokens (strings for sub-words) in a single string. + """ + tokens = self.convert_ids_to_tokens(ids) + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning + def _convert_token_to_id(self, token): + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + # to mimic paddlenlp.transformers.ernie_m.tokenizer.ErnieMTokenizer functioning + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.reverse_vocab.get(index, self.unk_token) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + r""" + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An ErnieM sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of input_id with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + _cls = [self.cls_token_id] + _sep = [self.sep_token_id] + return _cls + token_ids_0 + _sep + _sep + token_ids_1 + _sep + + def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None): + r""" + Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. An Ernie-M + offset_mapping has the following format: + + - single sequence: `(0,0) X (0,0)` + - pair of sequences: `(0,0) A (0,0) (0,0) B (0,0)` + + Args: + offset_mapping_ids_0 (`List[tuple]`): + List of char offsets to which the special tokens will be added. + offset_mapping_ids_1 (`List[tuple]`, *optional*): + Optional second list of wordpiece offsets for offset mapping pairs. + Returns: + `List[tuple]`: List of wordpiece offsets with the appropriate offsets of special tokens. + """ + if offset_mapping_1 is None: + return [(0, 0)] + offset_mapping_0 + [(0, 0)] + + return [(0, 0)] + offset_mapping_0 + [(0, 0), (0, 0)] + offset_mapping_1 + [(0, 0)] + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + r""" + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `encode` method. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`str`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: + The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0] + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create the token type IDs corresponding to the sequences passed. [What are token type + IDs?](../glossary#token-type-ids) Should be overridden in a subclass if the model has a special way of + building: those. + + Args: + token_ids_0 (`List[int]`): + The first tokenized sequence. + token_ids_1 (`List[int]`, *optional*): + The second tokenized sequence. + Returns: + `List[int]`: The token type ids. + """ + # called when `add_special_tokens` is True, so align with `build_inputs_with_special_tokens` method + if token_ids_1 is None: + # [CLS] X [SEP] + return (len(token_ids_0) + 2) * [0] + + # [CLS] A [SEP] [SEP] B [SEP] + return [0] * (len(token_ids_0) + 1) + [1] * (len(token_ids_1) + 3) + + def is_ch_char(self, char): + """ + is_ch_char + """ + if "\u4e00" <= char <= "\u9fff": + return True + return False + + def is_alpha(self, char): + """ + is_alpha + """ + if ("a" <= char <= "z") or ("A" <= char <= "Z"): + return True + return False + + def is_punct(self, char): + """ + is_punct + """ + if char in ",;:.?!~,;:。?!《》【】": + return True + return False + + def is_whitespace(self, char): + """ + is whitespace + """ + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + if len(char) == 1: + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + def load_vocab(self, filepath): + token_to_idx = {} + with io.open(filepath, "r", encoding="utf-8") as f: + for index, line in enumerate(f): + token = line.rstrip("\n") + token_to_idx[token] = int(index) + + return token_to_idx + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + + tokenizer_model_file = os.path.join(save_directory, "sentencepiece.bpe.model") + with open(tokenizer_model_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (vocab_file,) + + +__all__ = ["ErnieMTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c23b58f35012f1ddc00e2275c484810287de6f8 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_gptsan_japanese import * + from .modeling_gptsan_japanese import * + from .tokenization_gptsan_japanese import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..cd565810095955c1d7a6e199db93d2ef1a499173 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/configuration_gptsan_japanese.py @@ -0,0 +1,157 @@ +# coding=utf-8 +# Copyright 2023, HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""GPTSAN-japanese model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class GPTSanJapaneseConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GPTSanJapaneseModel`]. It is used to instantiate + a GPTSANJapanese model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPTSANJapanese + [Tanrei/GPTSAN-japanese](https://huggingface.co/Tanrei/GPTSAN-japanese) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 36000): + Vocabulary size of the GPTSANJapanese model. Defines the number of different tokens that can be represented + by the `inputs_ids` passed when calling [`GPTSanJapaneseModel`]. + max_position_embeddings (`int`, *optional*, defaults to 1280): + The maximum sequence length that this model might ever be used with. Defaults set this to 1280. + d_model (`int`, *optional*, defaults to 1024): + Size of the encoder layers and the pooler layer. + d_ff (`int`, *optional*, defaults to 8192): + Size of the intermediate feed forward layer in each `SwitchTransformersBlock`. + d_ext (`int`, *optional*, defaults to 4096): + Size of the intermediate feed forward layer in each Extra-layers. + d_spout (`int`, *optional*, defaults to 128): + Size of the `spout` vector. + num_switch_layers (`int`, *optional*, defaults to 10): + Number of layers in the Switch Transformer layer. + num_ext_layers (`int`, *optional*, defaults to 0): + Number of layers in the Extra-layers. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_experts (`int`, *optional*, defaults to 16): + Number of experts for each SwitchTransformer layer. + expert_capacity (`int`, *optional*, defaults to 128): + Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular + Transformer. + dropout_rate (`float`, *optional*, defaults to 0.0): + The ratio for all dropout layers. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + router_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the router. + router_jitter_noise (`float`, *optional*, defaults to 0.0): + Amount of noise to add to the router. Set it to 0.0 during prediction or set small value (usually 1e-2) + during training. + router_dtype (`str`, *optional*, default to `"float32"`): + The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the + *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961). + router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`): + Whether to ignore padding tokens when routing. + output_hidden_states (`bool`, *optional*, default to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. + initializer_factor (`float`, *optional*, defaults to 0.002): + A factor for initializing all weight matrices. + output_router_logits (`bool`, *optional*, default to `False`): + Whether or not to return the router logits of all experts. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models) + """ + + model_type = "gptsan-japanese" + keys_to_ignore_at_inference = [ + "past_key_values", + ] + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + vocab_size=36000, + max_position_embeddings=1280, + d_model=1024, + d_ff=8192, + d_ext=4096, + d_spout=128, + num_switch_layers=10, + num_ext_layers=0, + num_heads=16, + num_experts=16, + expert_capacity=128, + dropout_rate=0.0, + layer_norm_epsilon=1e-5, + router_bias=False, + router_jitter_noise=0.0, + router_dtype="float32", + router_ignore_padding_tokens=False, + output_hidden_states=False, + output_attentions=False, + initializer_factor=0.002, + output_router_logits=False, + use_cache=True, + separator_token_id=35998, + pad_token_id=35995, + eos_token_id=35999, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.d_model = d_model + self.d_ff = d_ff + self.d_ext = d_ext + self.d_spout = d_spout + self.num_switch_layers = num_switch_layers + self.num_ext_layers = num_ext_layers + self.num_layers = num_switch_layers + num_ext_layers + self.num_heads = num_heads + self.num_experts = num_experts + self.expert_capacity = expert_capacity + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.router_bias = router_bias + self.router_jitter_noise = router_jitter_noise + self.router_dtype = router_dtype + self.router_ignore_padding_tokens = router_ignore_padding_tokens + self.output_hidden_states = output_hidden_states + self.output_attentions = output_attentions + self.initializer_factor = initializer_factor + self.output_router_logits = output_router_logits + self.use_cache = use_cache + + super().__init__( + separator_token_id=separator_token_id, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +__all__ = ["GPTSanJapaneseConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..a84d000d44390fe6ae821fb1cdfba968d40a2b93 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/convert_gptsan_tf_checkpoint_to_pytorch.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Convert GPTSANJapanese checkpoints from the original repository to pytorch model.""" + +import argparse +import json +import os +from collections import OrderedDict + +import numpy as np +import tensorflow as tf +import torch + + +def convert_tf_gptsan_to_pt(args): + parameter_file = os.path.join(args.tf_model_dir, "parameters.json") + params = json.loads(open(parameter_file).read()) + if not params: + raise ValueError( + f"It seems that the json file at {parameter_file} is empty. Make sure you have a correct json file." + ) + if not args.output.endswith(".pt"): + args.output = args.output + ".pt" + new_state = OrderedDict() + with tf.device("/CPU:0"): + reader = tf.train.load_checkpoint(args.tf_model_dir) + shapes = reader.get_variable_to_shape_map() + for key_name in shapes.keys(): + vnp = reader.get_tensor(key_name).astype(np.float16) + if key_name.endswith("/adam_m") or key_name.endswith("/adam_v"): + continue + if key_name.startswith("pasts/"): + if key_name.startswith("pasts/mlp"): + player = int(key_name[9]) + elif key_name.startswith("pasts/out"): + player = 8 + name = "model.sqout.%d.weight" % (player * 2) # enter to nn.Sequencial with Tanh, so 2 at a time + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/moe"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/switch_gating/kernel"): + name = "model.blocks.%d.feed_forward.mlp.router.classifier.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/softmlp/kernel"): + name = "model.blocks.%d.feed_forward.soft_bypass_mlp.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/wo/kernel") or key_name.endswith("/wi/kernel"): + nlayer = key_name[-9:-7] + for i in range(16): + name = "model.blocks.%d.feed_forward.mlp.experts.expert_%d.%s.weight" % (player, i, nlayer) + state = ( + vnp[i].transpose([1, 0]).copy() + ) # In Mesh-Tensorflow, it is one array, so it is divided + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/mlp"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/p1/kernel"): + name = "model.blocks.%d.feed_forward.mlp.wi.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p1/bias"): + name = "model.blocks.%d.feed_forward.mlp.wi.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p2/kernel"): + name = "model.blocks.%d.feed_forward.mlp.wo.weight" % player + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.endswith("/p2/bias"): + name = "model.blocks.%d.feed_forward.mlp.wo.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/ln"): + player = int(key_name[8:].split("/")[0]) + if key_name.endswith("/b"): + name = "model.blocks.%d.feed_forward.norm.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/g"): + name = "model.blocks.%d.feed_forward.norm.weight" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/att"): + player = int(key_name[9:].split("/")[0]) + if key_name.endswith("/qkv/kernel"): + state = vnp.copy() # Compute same dimension as Mesh-tensorflow using einsum + state_q = state[:, 0, :, :] + state_k = state[:, 1, :, :] + state_v = state[:, 2, :, :] + state_q = ( + state_q.reshape([state_q.shape[0], state_q.shape[1] * state_q.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + state_k = ( + state_k.reshape([state_k.shape[0], state_k.shape[1] * state_k.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + state_v = ( + state_v.reshape([state_v.shape[0], state_v.shape[1] * state_v.shape[2]]) + .transpose([1, 0]) + .copy() + ) # Mesh-Tensorflow is a diagonal matrix + name = "model.blocks.%d.self_attn.self_attn.q_proj.weight" % player + new_state[name] = torch.tensor(state_q) + name = "model.blocks.%d.self_attn.self_attn.k_proj.weight" % player + new_state[name] = torch.tensor(state_k) + name = "model.blocks.%d.self_attn.self_attn.v_proj.weight" % player + new_state[name] = torch.tensor(state_v) + elif key_name.endswith("/o/kernel"): + name = "model.blocks.%d.self_attn.self_attn.out_proj.weight" % player + state = ( + vnp.reshape([vnp.shape[0] * vnp.shape[1], vnp.shape[2]]).transpose([1, 0]).copy() + ) # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/an"): + player = int(key_name[8:].split("/")[0]) + if key_name.endswith("/b"): + name = "model.blocks.%d.self_attn.norm.bias" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif key_name.endswith("/g"): + name = "model.blocks.%d.self_attn.norm.weight" % player + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + elif ( + key_name.startswith("model/wte") + or key_name.startswith("model/wpe") + or key_name.startswith("model/ete") + ): + nlayer = {"wte": "embed_tokens", "wpe": "position_embeddings", "ete": "extra_position_embeddings"}[ + key_name[-3:] + ] + name = "model.%s.weight" % nlayer + state = vnp.copy() # same in embedded + new_state[name] = torch.tensor(state) + if key_name.startswith("model/wte"): + name = "lm_head.weight" + state = vnp.copy() # same in embedded + new_state[name] = torch.tensor(state) + elif key_name.startswith("model/wob"): + name = "final_logits_bias" + state = vnp.copy() # same in embedded + state = state.reshape((1, -1)) + new_state[name] = torch.tensor(state) + elif key_name == "model/dense/kernel": + name = "model.last_project.weight" + state = vnp.transpose([1, 0]).copy() # Mesh-Tensorflow is a diagonal matrix + new_state[name] = torch.tensor(state) + elif key_name == "model/dense_1/bias": + name = "model.last_project.bias" + state = vnp.copy() # same because it is one dimensional + new_state[name] = torch.tensor(state) + torch.save(new_state, args.output) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="model converter.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--tf_model_dir", metavar="PATH", type=str, required=True, help="import model") + parser.add_argument("--output", metavar="PATH", type=str, required=True, help="output model") + args = parser.parse_args() + convert_tf_gptsan_to_pt(args) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..a35ea4a31199887a32e6912bfb2d3d6036635b38 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -0,0 +1,1337 @@ +# coding=utf-8 +# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GPTSANJapanese model.""" + +import copy +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ....activations import ACT2FN +from ....modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions +from ....modeling_utils import PreTrainedModel +from ....utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, +) +from .configuration_gptsan_japanese import GPTSanJapaneseConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GPTSanJapaneseConfig" +_CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### + + +def router_z_loss_func(router_logits: torch.Tensor) -> float: + r""" + Compute the router z-loss implemented in PyTorch. + + The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906). + It encourages router logits to remain small in an effort to improve stability. + + Args: + router_logits (`float`): + Input logits of shape [batch_size, sequence_length, num_experts] + + Returns: + Scalar router z-loss. + """ + num_groups, tokens_per_group, _ = router_logits.shape + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + return torch.sum(z_loss) / (num_groups * tokens_per_group) + + +def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + router_probs (`torch.Tensor`): + Probability assigned to each expert per token. Shape: [batch_size, seqeunce_length, num_experts]. + expert_indices (`torch.Tensor`): + Indices tensor of shape [batch_size, seqeunce_length] identifying the selected expert for a given token. + + Returns: + The auxiliary loss. + """ + num_experts = router_probs.shape[-1] + + # cast the expert indices to int64, otherwise one-hot encoding will fail + if expert_indices.dtype != torch.int64: + expert_indices = expert_indices.to(torch.int64) + + if len(expert_indices.shape) == 2: + expert_indices = expert_indices.unsqueeze(2) + + expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts) + + # For a given token, determine if it was routed to a given expert. + expert_mask = torch.max(expert_mask, axis=-2).values + + # cast to float32 otherwise mean will fail + expert_mask = expert_mask.to(torch.float32) + tokens_per_group_and_expert = torch.mean(expert_mask, axis=-2) + + router_prob_per_group_and_expert = torch.mean(router_probs, axis=-2) + return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2) + + +class GPTSanJapaneseDenseActDense(nn.Module): + """ + FFN Layer for Switch Transformer and Extra layers + + GPTSAN can mix Switch Transformer layers and normal Transformer layers This class is used as Expert in Switch + Transformer layers and as FFN in regular Transformer layers. RELU is used in the Switch Transformer layer, and + Swish is used in the normal Transformer layer, so there is a choice of which is used in the argument. + + """ + + def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False): + super().__init__() + d_inter = config.d_ext if ext_layer else config.d_ff + self.wi = nn.Linear(config.d_model, d_inter, bias=ext_layer) + self.wo = nn.Linear(d_inter, config.d_model, bias=ext_layer) + self.dropout = nn.Identity() if ext_layer else nn.Dropout(config.dropout_rate) + self.act = ACT2FN["swish" if ext_layer else "relu"] + + def forward(self, hidden_states): + r""" + Args: + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. + Returns: + torch.Tensor[num_groups, tokens_per_group, hidden_dim] + + """ + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class GPTSanJapaneseTop1Router(nn.Module): + """ + Router using tokens choose top-1 experts assignment. + + This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE + (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then + routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each + token is processed by an expert**, or that each expert receives at least one token. + + """ + + def __init__(self, config: GPTSanJapaneseConfig): + super().__init__() + self.num_experts = config.num_experts + self.expert_capacity = config.expert_capacity + self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias) + self.jitter_noise = config.router_jitter_noise + self.ignore_padding_tokens = config.router_ignore_padding_tokens + self.dtype = getattr(torch, config.router_dtype) + + def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Computes router probabilities from input hidden states. + + Args: + hidden_states (`torch.Tensor`): + (batch_size, sequence_length, hidden_dim) from which router probabilities are computed. + Returns: + router_probabilities (`torch.Tensor`): + Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each + token and expert. Used for routing tokens to experts. + router_logits (`torch.Tensor`): + Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits. + This is used later for computing router z-loss. + """ + # float32 is used to ensure stability. See the discussion of "selective precision" in + # https://arxiv.org/abs/2101.03961. + # We also store the previous dtype to cast back the output to the previous dtype + self.input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(self.dtype) + + if self.training and self.jitter_noise > 0: + # Multiply the token inputs by the uniform distribution - adding some noise + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + + # Shape: [num_groups, tokens_per_group, num_experts] + self._cast_classifier() + router_logits = self.classifier(hidden_states) + + # Apply Softmax and cast back to the original `dtype` + router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype) + return router_probabilities, router_logits + + def _cast_classifier(self): + r""" + `bitsandbytes` `Linear8bitLt` layers does not support manual casting Therefore we need to check if they are an + instance of the `Linear8bitLt` class by checking special attributes. + """ + if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")): + self.classifier = self.classifier.to(self.dtype) + + def forward(self, hidden_states: torch.Tensor) -> Tuple: + r""" + Generic forward function for every Router class. Each Router expects to have the same input hidden states + (`hidden_states`) corresponding to the hidden states for each token, the `expert_capacity` corresponding to the + number of tokens the Router will send to each expert, some Routers can send up to few tokens to each expert. + + Each Router works as the following: it expects the hidden states for each token, gets the `router_probs` and + `router_logits` from the `router_weights`. This will assign for each token, the raw probability to be assigned + to an expert. Then each Router class will have to define its own `_compute_routing_instructions`. + + Args: + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. + Returns: + Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the expert index, the router probs + and the router logits. The router probabilities and logits are required to compute the loss. + """ + router_probs, router_logits = self._compute_router_probabilities(hidden_states) + + expert_index = torch.argmax(router_probs, dim=-1) + expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts) + + # Mask tokens outside expert capacity. Sum over each sequence + token_priority = torch.cumsum(expert_index, dim=-2) + # mask if the token routed to to the expert will overflow + expert_capacity_mask = token_priority <= self.expert_capacity + expert_index = expert_index * expert_capacity_mask + + router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1) + return expert_index, router_probs, router_logits + + +class GPTSanJapaneseSparseMLP(nn.Module): + r""" + Implementation of the Switch Transformers Sparse MLP module. + """ + + def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense): + super().__init__() + # Step 1: Get the correct router according to its class + self.router = GPTSanJapaneseTop1Router(config) + + # Step 2: Get the experts + self.experts = nn.ModuleDict() + for idx in range(config.num_experts): + self.experts[f"expert_{idx}"] = expert_class(config) + + def forward(self, hidden_states): + r""" + Hold on, this will be slightly tricky to understand In the correct order, a MoE layer does the following: + + 1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)` + and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the + hidden states : they are broadcasted to the hidden states values (can be interpreted as a scaling factor). + + 2- Dispatch the tokens to its associated experts. We do a classic for loop over the experts and assign for each + expert the corresponding hidden states. + + """ + # Step 1: Get the router_mask from the router as wel as the probabilities + router_mask, router_probs, router_logits = self.router(hidden_states) + expert_index = torch.argmax(router_mask, dim=-1) + + # The routers introduced might not always map all the tokens, to a router, which means that some hidden states + # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the seleced ones. + + next_states = hidden_states.clone() + for idx, expert in enumerate(self.experts.values()): + token_indices = router_mask[:, :, idx].bool() + next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype) + + hidden_states = router_probs * next_states + return hidden_states, (router_logits, expert_index) + + +class GPTSanJapaneseLayerSparseFF(nn.Module): + r""" + Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module. + + Parameters: + config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + + def __init__(self, config: GPTSanJapaneseConfig): + super().__init__() + self.mlp = GPTSanJapaneseSparseMLP(config) + self.soft_bypass_mlp = nn.Linear(config.d_model, config.d_model, bias=False) + self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states, output_router_logits): + r""" + Args: + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. + output_router_logits (`bool`) : + output experts router output. + Returns: + torch.Tensor[num_groups, tokens_per_group, hidden_dim] + + """ + forwarded_states, router_tuple = self.mlp(hidden_states) + forwarded_states += torch.tanh(self.soft_bypass_mlp(hidden_states)) + output = hidden_states + self.norm(forwarded_states) + + if output_router_logits and router_tuple is not None: + return output, router_tuple + else: + return output + + +class GPTSanJapaneseLayerDenseFF(nn.Module): + r""" + Extra Transformers Feed Forward layer module. + + Parameters: + config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. + """ + + def __init__(self, config: GPTSanJapaneseConfig): + super().__init__() + # Check if it is a sparse layer, if not then it is a dense layer + self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True) + self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states): + r""" + Args: + hidden_states (`torch.Tensor`) : + [num_groups, tokens_per_group, hidden_dim] inputs to send to experts. + Returns: + torch.Tensor[num_groups, tokens_per_group, hidden_dim] + + """ + forwarded_states = self.mlp(hidden_states) + output = hidden_states + self.norm(forwarded_states) + return output + + +class GPTSanJapaneseAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[GPTSanJapaneseConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class GPTSanJapaneseLayerSelfAttention(nn.Module): + """ + Self Attention and Normalization Unit + """ + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.self_attn = GPTSanJapaneseAttention( + embed_dim=config.d_model, + num_heads=config.num_heads, + is_decoder=True, + bias=has_relative_attention_bias, + ) + self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + r""" + Self-attention and normalize block. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up + decoding. If `past_key_values` are used, the user can optionally input only the last + `decoder_input_ids` (those that don't have their past key value states given to this model) of shape + `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + Returns: + Tuple[torch.Tensor[num_groups, tokens_per_group, hidden_dim],...] + """ + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + atten_out = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=(1 - attention_mask) * torch.finfo(hidden_states.dtype).min, + layer_head_mask=head_mask, + output_attentions=output_attentions, + ) + if output_attentions: + attn_weights = (atten_out[1],) + else: + attn_weights = () + + attention_output = atten_out[0] + + hidden = hidden_states + self.norm(attention_output) + + if use_cache: + outputs = (hidden, atten_out[2]) # hidden, present, (attentions) + else: + outputs = (hidden,) # hidden, (attentions) + + return outputs + attn_weights + + +class GPTSanJapaneseBlock(nn.Module): + """ + Self Attention and FFN Unit + """ + + def __init__(self, config, ext_layer=False): + super().__init__() + self.self_attn = GPTSanJapaneseLayerSelfAttention(config) + self.feed_forward = GPTSanJapaneseLayerDenseFF(config) if ext_layer else GPTSanJapaneseLayerSparseFF(config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_router_tuple: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + r""" + GPTSAN transformer block. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up + decoding. If `past_key_values` are used, the user can optionally input only the last + `decoder_input_ids` (those that don't have their past key value states given to this model) of shape + `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`) : + output attention probabirities. + output_router_tuple: + output experts router logits and expert id. + Returns: + Tuple[torch.Tensor[num_groups, tokens_per_group, hidden_dim],...] + """ + atten_out = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attention_output = atten_out[0] + + if isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF): + sparse_out = self.feed_forward(attention_output, output_router_tuple) + if output_router_tuple: + hidden, router_tuple = sparse_out + else: + hidden = sparse_out + else: + hidden = self.feed_forward(attention_output) + + outputs = (hidden,) + atten_out[1:] + + if isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF) and output_router_tuple: + outputs += (router_tuple,) + + return outputs + + +class GPTSanJapanesePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPTSanJapaneseConfig + base_model_prefix = "gptsan_japanese" + supports_gradient_checkpointing = False + _no_split_modules = ["GPTSanJapaneseBlock"] + _skip_keys_device_placement = "past_key_values" + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "input_ids": input_ids, + "attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, nn.LayerNorm): + module.weight.data.fill_(factor * 1.0) + module.bias.data.zero_() + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, GPTSanJapaneseModel): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None: + module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, (GPTSanJapaneseModel, GPTSanJapaneseForConditionalGeneration)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, GPTSanJapaneseDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, GPTSanJapaneseAttention): + # Multi-headed attention + d_model = self.config.d_model + key_value_proj_dim = self.config.d_model + n_heads = self.config.num_heads + module.k_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.v_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + elif isinstance(module, GPTSanJapaneseSparseMLP): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_model + n_heads = self.config.num_heads + module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) + for idx in range(self.config.num_experts): + module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + if decoder_start_token_id is None: + raise ValueError( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. " + "See T5 docs for more information." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +GPTSAN_JAPANESE_START_DOCSTRING = r""" + + The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Swich transformer + based Japanese language model + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPTSAN_JAPANESE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. GPTSAN-japanese is a model that generates sentence + continuations or predicts tokens at mask positions. Special tokens required for inputs to the model are + automatically appended. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + An input that masks the Prefix part in the Prefix-LM input. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **prefix** input, + - 0 for tokens that are **not-prefix** input. + spout (`torch.Tensor` of shape `(batch_size, config.d_spout)`): + This vector is transformed through an 8-layer FFN and can be used instead of `past_key_values`. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. +""" + + +@add_start_docstrings( + "The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.", + GPTSAN_JAPANESE_START_DOCSTRING, +) +class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): + def __init__(self, config: GPTSanJapaneseConfig): + super().__init__(config) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) + self.config = copy.deepcopy(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + self.last_project = nn.Linear(config.d_model, config.d_model, bias=True) + self.act = ACT2FN["swish"] + + self.blocks = torch.nn.ModuleList([]) + for _ in range(config.num_switch_layers): + self.blocks.append(GPTSanJapaneseBlock(config)) + for _ in range(config.num_ext_layers): + self.blocks.append(GPTSanJapaneseBlock(config, ext_layer=True)) + + if config.num_ext_layers > 0: + self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model) + + if config.d_spout: + spouts = [] + for _ in range(8): + spouts.append(nn.Linear(config.d_spout, config.d_spout, bias=False)) + spouts.append(nn.Tanh()) + spouts.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False)) + self.spout = nn.Sequential(*spouts) + + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.FloatTensor] = None, + spout: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + num_precontext: Optional[torch.LongTensor] = None, + ) -> Union[MoEModelOutputWithPastAndCrossAttentions, Tuple[torch.FloatTensor]]: + r""" + num_precontext (`torch.LongTensor` of shape `(batch_size,1)`): + length of `hybrid` input tokens in the input. Tokens up to this length refer to both front and back like + BERT, tokens after that refer only to front like GPT. see also: + https://github.com/tanreinama/GPTSAN/blob/main/report/model.md + + Returns: + `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns + MoEModelOutputWithPastAndCrossAttentions insted of tuple + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + device = self.position_embeddings.weight.device + if input_ids is None: + input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None + if inputs_embeds is not None: + raise NotImplementedError( + "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead." + ) + num_pasts_contexts = 0 + num_batch = input_ids.shape[0] + pasts_or_spout_value = None + if past_key_values is not None: + num_pasts_contexts = past_key_values[0][0].shape[2] + elif self.config.d_spout and spout is not None: + # `spout` is a special input vector specific to GPTSAN + # This controls the output by projecting embedded information such as the class of sentences during learning. + # It should passed instead of the first past_key_value. + # See the original GPTSAN repository for details + num_pasts_contexts += 1 + + # If there is an attention_mask, increase first one for spout + if self.config.d_spout and spout is not None and attention_mask is not None: + attention_mask_with_spout = torch.ones(num_batch, attention_mask.shape[1] + 1, device=device) + attention_mask_with_spout[:, 1:] -= 1 - attention_mask # 1st token should be spout + attention_mask = attention_mask_with_spout # update attention_mask + + if num_precontext is not None: + # `num_precontext` is the number of tokens that refer to each other in prefix-lm + # created per batch, so dimension of num_precontext should be [batch, 1] + if not ( + len(num_precontext.shape) == 2 and num_precontext.shape[1] == 1 + ): # num_precontext Should be [batch,1] + raise ValueError("num_precontext should be [batch, 1] size.") + num_precontext = torch.reshape(num_precontext, [-1]) + else: + num_precontext = torch.zeros([num_batch]).int().to(device) + + num_input_contexts = input_ids.shape[1] + num_output_contexts = num_input_contexts + num_pasts_contexts + + hidden_states = self.embed_tokens(input_ids) + + if past_key_values is not None: + pasts_or_spout_value = past_key_values + elif self.config.d_spout and spout is not None: + # Make vector from `spout` of GPTSAN to the same shape as past_key_values + pasts_or_spout_value = self.spout(spout) # projecting `spout` vector + pasts_or_spout_value = torch.reshape( + pasts_or_spout_value, + [ + num_batch, + self.config.num_layers, + 2, + self.config.num_heads, + num_pasts_contexts, + self.config.d_model // self.config.num_heads, + ], + ) + pasts_or_spout_value = torch.split(pasts_or_spout_value, [1] * self.config.num_layers, dim=1) + # make same shape as past_key_values + pasts_or_spout_value = tuple( + tuple([b.squeeze(1) for b in torch.split(a.squeeze(1), [1, 1], dim=1)]) for a in pasts_or_spout_value + ) + else: + pasts_or_spout_value = [None] * self.config.num_layers + + # Token position considering spout and pasts + token_position = torch.arange(num_input_contexts).to(device) + num_pasts_contexts + + if attention_mask is None: + attention_mask = torch.ones(num_batch, num_input_contexts, device=device) + + # positions for get position_embeddings + gather_position = ( + ( + torch.zeros((num_batch, self.config.d_model, num_input_contexts)).to(device) + + token_position.unsqueeze(0) + ) + .transpose(1, 2) + .long() + ) + # When padding with padding_side="left", zeros line up on the left side of attention_mask, so position_embeddings is shifted accordingly + gather_position -= (1 - attention_mask).argmin(dim=-1).unsqueeze(1).unsqueeze(2) + gather_position = torch.clip(gather_position, num_pasts_contexts, self.config.max_position_embeddings - 1) + + # attention_mask is applied per batch + for i in range(num_batch): + hidden_states[i] += torch.gather(self.position_embeddings.weight, dim=0, index=gather_position[i]) + + # Create a mask to be used when making the prefix Input length of Prefix-LM variable + causal_mask = ( + torch.tril(torch.ones((num_output_contexts, num_output_contexts), dtype=torch.uint8)) + .view(1, 1, num_output_contexts, num_output_contexts) + .to(device) + ) + prefix_lm_mask = causal_mask[:, :, -num_input_contexts:, :] + if token_type_ids is not None: + token_type_ids = token_type_ids.unsqueeze(1).unsqueeze(2) + prefix_lm_mask = ((prefix_lm_mask + token_type_ids) > 0).float() + # Marge prefix_lm_mask and attention_mask + extended_attention_mask = prefix_lm_mask * attention_mask.unsqueeze(1).unsqueeze(2) + + # Prepare head mask if needed + if head_mask is not None: + head_mask = self.get_head_mask( + head_mask, self.config.num_switch_layers + self.config.num_ext_layers + ) # n_layer x batch x n_heads x N x N + + # outputs + present_key_value_states = () if self.config.use_cache or use_cache else None + all_hidden_states = () if self.config.output_hidden_states or output_hidden_states else None + all_attentions = () if self.config.output_attentions or output_attentions else None + all_router_probs = () if self.config.output_router_logits or output_router_logits else None + + for layer, past in enumerate(pasts_or_spout_value): + if layer == self.config.num_switch_layers: + if self.config.num_ext_layers > 0: + # extra_position_embeddings are extra position embeddings that are only created when extending the model with code from the original GPTSAN repository. Not used in the default model. + # However, it is created when you create an additional layer and partially train only that location. + # Therefore, convert_gptsan_tf_checkpoint_to_pytorch.py is used when converting and loading models created in the original GPTSAN repository. + for i in range(num_batch): + hidden_states[i] += torch.gather( + self.extra_position_embeddings.weight, dim=0, index=gather_position[i] + ) + + output_router_tuple = ( + self.config.output_router_logits or output_router_logits + ) and layer < self.config.num_switch_layers + block_output = self.blocks[layer]( + hidden_states=hidden_states, + past_key_value=past, + attention_mask=extended_attention_mask, + head_mask=head_mask, + use_cache=self.config.use_cache or use_cache, + output_attentions=self.config.output_attentions or output_attentions, + output_router_tuple=output_router_tuple, + ) + + outpos = 0 + hidden_states = block_output[outpos] + if self.config.output_hidden_states or output_hidden_states: + all_hidden_states += (hidden_states,) + if self.config.use_cache or use_cache: + outpos += 1 + present = block_output[outpos] + present_key_value_states += (present,) + if self.config.output_attentions or output_attentions: + outpos += 1 + attention_probs = block_output[outpos] + all_attentions += (attention_probs,) + if output_router_tuple: + outpos += 1 + router_tuple = block_output[outpos] + all_router_probs.append(router_tuple[0]) + + hidden_states = self.last_project(hidden_states) + hidden_states = self.act(hidden_states) + + if self.config.output_hidden_states or output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_router_probs, + ] + if v is not None + ) + + return MoEModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + router_probs=all_router_probs, + ) + + +@add_start_docstrings( + "The bare GPTSAN-japanese Model with a language modeling head.", + GPTSAN_JAPANESE_START_DOCSTRING, +) +class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: GPTSanJapaneseConfig): + super().__init__(config) + self.model = GPTSanJapaneseModel(config) + self.register_buffer("final_logits_bias", torch.zeros([1, config.vocab_size])) + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + if not self.config.torchscript: + self.lm_head.weight = self.model.embed_tokens.weight + + @add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.FloatTensor] = None, + spout: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], MoECausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + `MoECausalLMOutputWithPast` or `tuple` if `return_dict` returns MoECausalLMOutputWithPast insted of tuple + + Example: + + Text Generation with regular LM Model + ```python + >>> from transformers import AutoModel, AutoTokenizer, trainer_utils + + >>> device = "cuda" + >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device) + >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> x_token = tokenizer("織田信長は、", return_tensors="pt") + >>> trainer_utils.set_seed(30) + >>> input_ids = x_token.input_ids.to(device) + >>> gen_token = model.generate(input_ids, max_new_tokens=50) + >>> tokenizer.decode(gen_token[0]) + "織田信長は、政治・軍事の中枢まで掌握した政治家であり、日本史上類を見ない驚異的な軍事侵攻を続け..." + ``` + + Text Generation with Prefix-LM Model + ```python + >>> from transformers import AutoModel, AutoTokenizer, trainer_utils + + >>> device = "cuda" + >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device) + >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> x_token = tokenizer("", prefix_text="織田信長は、", return_tensors="pt") + >>> trainer_utils.set_seed(30) + >>> input_ids = x_token.input_ids.to(device) + >>> token_type_ids = x_token.token_type_ids.to(device) + >>> gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50) + >>> tokenizer.decode(gen_token[0]) + "織田信長は、政治・外交で数々の戦果を上げるが、1568年からは、いわゆる本能寺の変で細川晴元に暗殺される..." + ``` + + Simultaneously Text Generation And Masked Language Model + ```python + >>> from transformers import AutoModel, AutoTokenizer, trainer_utils + + >>> device = "cuda" + >>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device) + >>> tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> masked_sentence = "武田信玄は、<|inputmask|>時代ファンならぜひ押さえ<|inputmask|>きたい名将の一人。" + >>> x_token = tokenizer("", prefix_text=masked_sentence, return_tensors="pt") + >>> trainer_utils.set_seed(30) + >>> input_ids = x_token.input_ids.to(device) + >>> token_type_ids = x_token.token_type_ids.to(device) + >>> out_lm_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50) + >>> out_mlm_token = model(input_ids, token_type_ids=token_type_ids).logits.argmax(axis=-1) + >>> tokenizer.decode(out_mlm_token[0]) + "武田信玄は、戦国時代ファンならぜひ押さえておきたい名将の一人。" + + >>> tokenizer.decode(out_lm_token[0][input_ids.shape[1] :]) + "武田氏の三代に渡った武田家のひとり\n甲斐市に住む、日本史上最大の戦国大名。..." + ```""" + SEG_TOKEN = self.config.separator_token_id + use_cache = use_cache or self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + model_return_dict = True + num_precontext = None + if input_ids is not None: + num_batch = input_ids.shape[0] + num_precontext = torch.zeros([num_batch]).int().to(input_ids.device) + where_separators = torch.where(input_ids == SEG_TOKEN) + num_precontext[where_separators[0]] += where_separators[1] + num_precontext = num_precontext.unsqueeze(1) + + outputs = self.model( + input_ids, + attention_mask, + token_type_ids, + spout, + past_key_values, + head_mask, + use_cache, + inputs_embeds, + decoder_inputs_embeds, + output_attentions, + output_hidden_states, + model_return_dict, + output_router_logits, + num_precontext, + ) + + lm_logits = self.lm_head(outputs[0]) + if lm_logits.shape[-1] == self.final_logits_bias.shape[-1]: + lm_logits = lm_logits + self.final_logits_bias + + loss = None + z_loss = None + router_probs = None + aux_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + + loss_fct = nn.CrossEntropyLoss(ignore_index=-100) + + if output_router_logits: + # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder + router_logits, expert_indexes = self._unpack_router_logits(outputs.router_probs) + z_loss = router_z_loss_func(router_logits) + router_probs = nn.Softmax(dim=-1)(router_logits) + aux_loss = load_balancing_loss_func(router_probs, expert_indexes) + + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + return tuple( + v + for v in [ + loss, + lm_logits, + outputs.past_key_values, + outputs.hidden_states, + outputs.router_probs, + z_loss, + aux_loss, + ] + if v is not None + ) + + return MoECausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_probs, + z_loss=z_loss, + aux_loss=aux_loss, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + token_type_ids: Optional[torch.FloatTensor] = None, + spout: Optional[Union[List, torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + **kwargs, + ): + if isinstance(spout, list): + spout = torch.tensor(spout).float() + if input_ids is not None: + spout = spout.to(input_ids.device) + if past_key_values is not None: + return { + "input_ids": input_ids[:, -1:] if input_ids is not None else None, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids[:, -1:] if token_type_ids is not None else None, + "spout": spout, + "past_key_values": past_key_values, + } + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + "spout": spout, + "past_key_values": None, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.model.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def _unpack_router_logits(self, router_outputs): + total_router_logits = [] + total_expert_indexes = [] + for router_output in router_outputs: + if len(router_output[0].shape) > 1: + router_logits, expert_indexes = router_output + total_router_logits.append(router_logits) + total_expert_indexes.append(expert_indexes) + return torch.cat(total_router_logits, dim=1), torch.cat(total_expert_indexes, dim=1) + + +__all__ = ["GPTSanJapaneseForConditionalGeneration", "GPTSanJapaneseModel", "GPTSanJapanesePreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py new file mode 100644 index 0000000000000000000000000000000000000000..c93ea87278d767fb12b48554ffd6b7f611f7e034 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/gptsan_japanese/tokenization_gptsan_japanese.py @@ -0,0 +1,518 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for GPTSANJapanese.""" + +import collections +import json +import os +import re +import sys +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ....tokenization_utils import PreTrainedTokenizer +from ....tokenization_utils_base import ( + BatchEncoding, + PreTokenizedInput, + PreTokenizedInputPair, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ....utils import PaddingStrategy, logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "emoji_file": "emoji.json"} + + +def load_vocab_and_emoji(vocab_file, emoji_file): + """Loads a vocabulary file and emoji file into a dictionary.""" + with open(emoji_file, "r", encoding="utf-8") as f: + emoji = json.loads(f.read()) + + vocab = collections.OrderedDict() + raw_vocab = collections.OrderedDict() + ids_to_tokens = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as f: + token = f.readlines() + token = [[t.rstrip("\n")] if (t == ",\n" or "," not in t) else t.rstrip("\n").split(",") for t in token] + for idx, b in enumerate(token): + ids_to_tokens[idx] = b + raw_vocab[",".join(b)] = idx + for wd in b: + vocab[wd] = idx + + return vocab, raw_vocab, ids_to_tokens, emoji + + +class GPTSanJapaneseTokenizer(PreTrainedTokenizer): + """ + This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications + - Decoding byte0~byte255 tokens correctly + - Added bagofword token handling + - Return token_type_ids for Prefix-LM model + The bagofword token represents a repetition of the previous token and is converted to 3 consecutive tokens when + decoding In addition, the original Japanese special Sub-Word-Encoding has been released in this repository + (https://github.com/tanreinama/Japanese-BPEEncoder_V2). The token_type_ids is a mask indicating the prefix input + position of the Prefix-LM model. To specify a prefix position, specify a prefix input for prefix_text, or specify a + sentence of the prefix part and the part after it as a text pair of batch input. + + Example: + + ```python + >>> from transformers import GPTSanJapaneseTokenizer + + >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> # You can confirm both 慶応 and 慶應 are encoded to 17750 + >>> tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"] + [35993, 35998, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281] + + >>> # Both 慶応 and 慶應 are decoded to 慶応 + >>> tokenizer.decode(tokenizer("吾輩は猫である🐯。実は慶応(慶應)大学出身")["input_ids"]) + '吾輩は猫である🐯。実は慶応(慶応)大学出身' + ``` + + Example for Prefix-LM: + + ```python + >>> from transformers import GPTSanJapaneseTokenizer + + >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["input_ids"] + [35993, 34347, 31459, 30647, 31448, 25, 30659, 35729, 35676, 35998, 32417, 30647, 17750, 35589, 17750, 35590, 321, 1281] + + >>> # Mask for Prefix-LM inputs + >>> tokenizer("実は慶応(慶應)大学出身", prefix_text="吾輩は猫である🐯。")["token_type_ids"] + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ``` + + Example for batch encode: + + ```python + >>> from transformers import GPTSanJapaneseTokenizer + + >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["input_ids"] + [[35993, 35998, 8640, 25948, 35993, 35998, 30647, 35675, 35999, 35999], [35993, 35998, 10382, 9868, 35993, 35998, 30646, 9459, 30646, 35675]] + + >>> # Mask for Prefix-LM inputs + >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["token_type_ids"] + [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + + >>> # Mask for padding + >>> tokenizer([["武田信玄", "は、"], ["織田信長", "の配下の、"]], padding=True)["attention_mask"] + [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + ``` + + Args: + vocab_file (`str`): + File containing the vocabulary. + emoji_file (`str`): + File containing the emoji. + unk_token (`str`, *optional*, defaults to `"<|nottoken|>"`): + The token used for unknown charactor + pad_token (`str`, *optional*, defaults to `"<|separator|>"`): + The token used for padding + bos_token (`str`, *optional*, defaults to `"<|startoftext|>"`): + The beginning of sequence token. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + sep_token (`str`, *optional*, defaults to `"<|segmenter|>"`): + A special token to separate token to prefix part and general input part. + do_clean_text (`bool`, *optional*, defaults to `False`): + Whether or not to clean text for URL, EMAIL, TEL, Japanese DATE and Japanese PRICE. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask", "token_type_ids"] + + def __init__( + self, + vocab_file, + emoji_file, + unk_token="<|nottoken|>", + pad_token="<|separator|>", + bos_token="<|startoftext|>", + eos_token="<|endoftext|>", + sep_token="<|segmenter|>", + do_clean_text=False, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + if not os.path.isfile(emoji_file): + raise ValueError( + f"Can't find a emoji file at path '{emoji_file}'. To load the emoji information from a Google" + " pretrained model use `tokenizer = GPTSanJapaneseTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.do_clean_text = do_clean_text + self.vocab, self.raw_vocab, self.ids_to_tokens, self.emoji = load_vocab_and_emoji(vocab_file, emoji_file) + self.subword_tokenizer = SubWordJapaneseTokenizer( + vocab=self.vocab, ids_to_tokens=self.ids_to_tokens, emoji=self.emoji + ) + + super().__init__( + unk_token=unk_token, + pad_token=pad_token, + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + do_clean_text=do_clean_text, + **kwargs, + ) + + @property + def vocab_size(self): + # self.vocab contains support for character fluctuation unique to Japanese, and has a large number of vocab + return len(self.raw_vocab) + + def get_vocab(self): + return dict(self.raw_vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + return self.subword_tokenizer.tokenize(text, clean=self.do_clean_text) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.subword_tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + words = [] + byte_tokens = [] + for word in tokens: + if word[:6] == "<|byte" and word[-2:] == "|>": + byte_tokens.append(int(word[6:-2])) + else: + if len(byte_tokens) > 0: + words.append(bytearray(byte_tokens).decode("utf-8", errors="replace")) + byte_tokens = [] + if word[:7] == "<|emoji" and word[-2:] == "|>": + words.append(self.emoji["emoji_inv"][word]) + elif word == "": + words.append(" ") + elif word == "
": + words.append("\n") + elif word == "": + words.append("\t") + elif word == "": + words.append("▀") + elif word == "": + words.append("ǀ") + elif word == "": + words.append("‖") + elif word == "<|bagoftoken|>": + if len(words) > 0: + words.append(words[-1]) + words.append(words[-1]) + words.append(words[-1]) + elif word.startswith("<|") and word.endswith("|>"): + words.append("") + else: + words.append(word) + if len(byte_tokens) > 0: + words.append(bytearray(byte_tokens).decode("utf-8", errors="replace")) + text = "".join(words) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + emoji_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["emoji_file"] + ) + else: + vocab_file = ( + (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["vocab_file"] + ) + emoji_file = ( + (filename_prefix + "-" if filename_prefix else "") + save_directory + VOCAB_FILES_NAMES["emoji_file"] + ) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token_index, token in self.ids_to_tokens.items(): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(",".join(token) + "\n") + index += 1 + with open(emoji_file, "w", encoding="utf-8") as writer: + json.dump(self.emoji, writer) + return vocab_file, emoji_file + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + # docstyle-ignore + """ + The tokenizer returns token_type_ids as separators between the Prefix part and the rest. + token_type_ids is 1 for the Prefix part and 0 for the rest of the token. + + Example: + ```python + >>> from transformers import GPTSanJapaneseTokenizer + + >>> tokenizer = GPTSanJapaneseTokenizer.from_pretrained("Tanrei/GPTSAN-japanese") + >>> x_token = tokenizer("アイウエ") + >>> # input_ids: | SOT | SEG | ア | イ | ウ | エ | + >>> # token_type_ids: | 1 | 0 | 0 | 0 | 0 | 0 | + + >>> x_token = tokenizer("", prefix_text="アイウエ") + >>> # input_ids: | SOT | ア | イ | ウ | エ | SEG | + >>> # token_type_ids: | 1 | 1 | 1 | 1 | 1 | 0 | + + >>> x_token = tokenizer("ウエ", prefix_text="アイ") + >>> # input_ids: | SOT | ア | イ | SEG | ウ | エ | + >>> # token_type_ids: | 1 | 1 | 1 | 0 | 0 | 0 | + ```""" + prefix_len = 0 + if self.sep_token in self.vocab: + segid = self.vocab[self.sep_token] + if segid in token_ids_0: + prefix_len = token_ids_0.index(segid) + if token_ids_1 is None: + total_len = len(token_ids_0) + else: + total_len = len(token_ids_0 + token_ids_1) + return prefix_len * [1] + (total_len - prefix_len) * [0] + + def prepare_for_tokenization(self, text, prefix_text=None, add_sep_token=None, **kwargs): + # GPTSAN inserts extra SEP tokens in Prefix-LM in addition to SOT for text generation. + # SOT at the beginning of the text, and SEP at the separator between the Prefix part and the rest. + if add_sep_token is None: + add_sep_token = self.sep_token not in text # If insert un-prefix position explicitly + prepared = self.bos_token if self.bos_token in self.vocab else "" + prepared += prefix_text if prefix_text is not None else "" + if add_sep_token: + prepared += self.sep_token if self.sep_token in self.vocab else "" + prepared += text + return (prepared, kwargs) + + def _batch_encode_plus( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], List[TextInputPair], List[PreTokenizedInput], List[PreTokenizedInputPair] + ], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # This tokenizer converts input text pairs into Prefix input and subsequent input + if isinstance(batch_text_or_text_pairs[0], tuple) or isinstance(tuple(batch_text_or_text_pairs[0]), list): + # As a single text with an explicit un-prefix position + batch_prefix_texts = [] + for pref, txt in batch_text_or_text_pairs: + batch_prefix_texts.append(pref + self.sep_token + txt) + batch_text_or_text_pairs = batch_prefix_texts + + return super()._batch_encode_plus( + batch_text_or_text_pairs, + add_special_tokens, + padding_strategy, + truncation_strategy, + max_length, + stride, + is_split_into_words, + pad_to_multiple_of, + return_tensors, + return_token_type_ids, + return_attention_mask, + return_overflowing_tokens, + return_special_tokens_mask, + return_offsets_mapping, + return_length, + verbose, + **kwargs, + ) + + +class SubWordJapaneseTokenizer: + """ + This tokenizer is based on GPTNeoXJapaneseTokenizer and has the following modifications + - Decoding byte0~byte255 tokens correctly + - Added bagofword token handling + + https://github.com/tanreinama/Japanese-BPEEncoder_V2 This tokenizer class is under MIT Lisence according to the + original repository. + + MIT License + + Copyright (c) 2020 tanreinama + + Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all copies or substantial portions of + the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO + THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + """ + + def __init__(self, vocab, ids_to_tokens, emoji): + self.vocab = vocab # same as swe + self.ids_to_tokens = ids_to_tokens # same as bpe + self.emoji = emoji + self.maxlen = np.max([len(w) for w in self.vocab.keys()]) + self.content_repatter1 = re.compile(r"(https?|ftp)(:\/\/[-_\.!~*\'()a-zA-Z0-9;\/?:\@&=\+$,%#]+)") + self.content_repatter2 = re.compile(r"[A-Za-z0-9\._+]*@[\-_0-9A-Za-z]+(\.[A-Za-z]+)*") + self.content_repatter3 = re.compile(r"[\(]{0,1}[0-9]{2,4}[\)\-\(]{0,1}[0-9]{2,4}[\)\-]{0,1}[0-9]{3,4}") + self.content_repatter4 = re.compile( + r"([12]\d{3}[/\-年])*(0?[1-9]|1[0-2])[/\-月]((0?[1-9]|[12][0-9]|3[01])日?)*(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*" + ) + self.content_repatter5 = re.compile( + r"(明治|大正|昭和|平成|令和|㍾|㍽|㍼|㍻|\u32ff)\d{1,2}年(0?[1-9]|1[0-2])月(0?[1-9]|[12][0-9]|3[01])日(\d{1,2}|:|\d{1,2}時|\d{1,2}分|\(日\)|\(月\)|\(火\)|\(水\)|\(木\)|\(金\)|\(土\)|㈰|㈪|㈫|㈬|㈭|㈮|㈯)*" + ) + # The original version of this regex displays catastrophic backtracking behaviour. We avoid this using + # possessive quantifiers in Py >= 3.11. In versions below this, we avoid the vulnerability using a slightly + # different regex that should generally have the same behaviour in most non-pathological cases. + if sys.version_info >= (3, 11): + self.content_repatter6 = re.compile( + r"(?:\d,\d{3}|[\d億])*+" + r"(?:\d,\d{3}|[\d万])*+" + r"(?:\d,\d{3}|[\d千])*+" + r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+" + r"(?:\(税込\)|\(税抜\)|\+tax)*" + ) + else: + self.content_repatter6 = re.compile( + r"(?:\d,\d{3}|[\d億万千])*" + r"(?:千円|万円|千万円|円|千ドル|万ドル|千万ドル|ドル|千ユーロ|万ユーロ|千万ユーロ|ユーロ)+" + r"(?:\(税込\)|\(税抜\)|\+tax)*" + ) + keisen = "─━│┃┄┅┆┇┈┉┊┋┌┍┎┏┐┑┒┓└┕┖┗┘┙┚┛├┝┞┟┠┡┢┣┤┥┦┧┨┩┪┫┬┭┮┯┰┱┲┳┴┵┶┷┸┹┺┻┼┽┾┿╀╁╂╃╄╅╆╇╈╉╊╋╌╍╎╏═║╒╓╔╕╖╗╘╙╚╛╜╝╞╟╠╡╢╣╤╥╦╧╨╩╪╫╬╭╮╯╰╱╲╳╴╵╶╷╸╹╺╻╼╽╾╿" + blocks = "▀▁▂▃▄▅▆▇█▉▊▋▌▍▎▏▐░▒▓▔▕▖▗▘▙▚▛▜▝▞▟" + self.content_trans1 = str.maketrans(dict.fromkeys(keisen + blocks, "")) + + def __len__(self): + return len(self.ids_to_tokens) + + def clean_text(self, content): + content = self.content_repatter1.sub("", content) + content = self.content_repatter2.sub("", content) + content = self.content_repatter3.sub("", content) + content = self.content_repatter4.sub("", content) + content = self.content_repatter5.sub("", content) + content = self.content_repatter6.sub("", content) + content = content.translate(self.content_trans1) + while "" in content: + content = content.replace("", "") + return content + + def tokenize(self, text, clean=False): + text = text.replace(" ", "") + text = text.replace(" ", "") + text = text.replace("\r\n", "
") + text = text.replace("\n", "
") + text = text.replace("\r", "
") + text = text.replace("\t", "") + text = text.replace("—", "ー") + text = text.replace("−", "ー") + for k, v in self.emoji["emoji"].items(): + if k in text: + text = text.replace(k, v) + if clean: + text = self.clean_text(text) + + def check_simbol(x): + e = x.encode() + if len(x) == 1 and len(e) == 2: + c = (int(e[0]) << 8) + int(e[1]) + if ( + (c >= 0xC2A1 and c <= 0xC2BF) + or (c >= 0xC780 and c <= 0xC783) + or (c >= 0xCAB9 and c <= 0xCBBF) + or (c >= 0xCC80 and c <= 0xCDA2) + ): + return True + return False + + def checku2e(x): + e = x.encode() + if len(x) == 1 and len(e) == 3: + c = (int(e[0]) << 16) + (int(e[1]) << 8) + int(e[2]) + if c >= 0xE28080 and c <= 0xE2B07F: + return True + return False + + pos = 0 + result = [] + while pos < len(text): + end = min(len(text), pos + self.maxlen + 1) if text[pos] == "<" else pos + 3 + candidates = [] # (token_id, token, pos) + for e in range(end, pos, -1): + wd = text[pos:e] + if wd in self.vocab: + if wd[0] == "<" and len(wd) > 2: + candidates = [(self.vocab[wd], wd, e)] + break + else: + candidates.append((self.vocab[wd], wd, e)) + if len(candidates) > 0: + # the smallest token_id is adopted + _, wd, e = sorted(candidates, key=lambda x: x[0])[0] + result.append(wd) + pos = e + else: + end = pos + 1 + wd = text[pos:end] + if check_simbol(wd): + result.append("") + elif checku2e(wd): + result.append("") + else: + for i in wd.encode("utf-8"): + result.append("<|byte%d|>" % i) + pos = end + return result + + def convert_id_to_token(self, index): + return self.ids_to_tokens[index][0] + + +__all__ = ["GPTSanJapaneseTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4b3eb1be2b4e69dcad1540abdd91f412de0ca2 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_graphormer import * + from .modeling_graphormer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx new file mode 100644 index 0000000000000000000000000000000000000000..a0fafbdee53b55efb9596036817b03be0d006992 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/algos_graphormer.pyx @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation and HuggingFace +# Licensed under the MIT License. + +import cython + +cimport numpy +from cython.parallel cimport parallel, prange + +import numpy as np + + +# Reduce this number if matrices are too big for large graphs +UNREACHABLE_NODE_DISTANCE = 510 + +def floyd_warshall(adjacency_matrix): + """ + Applies the Floyd-Warshall algorithm to the adjacency matrix, to compute the + shortest paths distance between all nodes, up to UNREACHABLE_NODE_DISTANCE. + """ + (nrows, ncols) = adjacency_matrix.shape + assert nrows == ncols + cdef unsigned int n = nrows + + adj_mat_copy = adjacency_matrix.astype(np.int32, order='C', casting='safe', copy=True) + assert adj_mat_copy.flags['C_CONTIGUOUS'] + cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] M = adj_mat_copy + cdef numpy.ndarray[numpy.int32_t, ndim=2, mode='c'] path = -1 * np.ones([n, n], dtype=np.int32) + + cdef unsigned int i, j, k + cdef numpy.int32_t M_ij, M_ik, cost_ikkj + cdef numpy.int32_t* M_ptr = &M[0,0] + cdef numpy.int32_t* M_i_ptr + cdef numpy.int32_t* M_k_ptr + + # set unreachable nodes distance to UNREACHABLE_NODE_DISTANCE + for i in range(n): + for j in range(n): + if i == j: + M[i][j] = 0 + elif M[i][j] == 0: + M[i][j] = UNREACHABLE_NODE_DISTANCE + + # floyed algo + for k in range(n): + M_k_ptr = M_ptr + n*k + for i in range(n): + M_i_ptr = M_ptr + n*i + M_ik = M_i_ptr[k] + for j in range(n): + cost_ikkj = M_ik + M_k_ptr[j] + M_ij = M_i_ptr[j] + if M_ij > cost_ikkj: + M_i_ptr[j] = cost_ikkj + path[i][j] = k + + # set unreachable path to UNREACHABLE_NODE_DISTANCE + for i in range(n): + for j in range(n): + if M[i][j] >= UNREACHABLE_NODE_DISTANCE: + path[i][j] = UNREACHABLE_NODE_DISTANCE + M[i][j] = UNREACHABLE_NODE_DISTANCE + + return M, path + + +def get_all_edges(path, i, j): + """ + Recursive function to compute all possible paths between two nodes from the graph adjacency matrix. + """ + cdef int k = path[i][j] + if k == -1: + return [] + else: + return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) + + +def gen_edge_input(max_dist, path, edge_feat): + """ + Generates the full edge feature and adjacency matrix. + Shape: num_nodes * num_nodes * max_distance_between_nodes * num_edge_features + Dim 1 is the input node, dim 2 the output node of the edge, dim 3 the depth of the edge, dim 4 the feature + """ + (nrows, ncols) = path.shape + assert nrows == ncols + cdef unsigned int n = nrows + cdef unsigned int max_dist_copy = max_dist + + path_copy = path.astype(long, order='C', casting='safe', copy=True) + edge_feat_copy = edge_feat.astype(long, order='C', casting='safe', copy=True) + assert path_copy.flags['C_CONTIGUOUS'] + assert edge_feat_copy.flags['C_CONTIGUOUS'] + + cdef numpy.ndarray[numpy.int32_t, ndim=4, mode='c'] edge_fea_all = -1 * np.ones([n, n, max_dist_copy, edge_feat.shape[-1]], dtype=np.int32) + cdef unsigned int i, j, k, num_path, cur + + for i in range(n): + for j in range(n): + if i == j: + continue + if path_copy[i][j] == UNREACHABLE_NODE_DISTANCE: + continue + path = [i] + get_all_edges(path_copy, i, j) + [j] + num_path = len(path) - 1 + for k in range(num_path): + edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] + + return edge_fea_all diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2342913d63ffa120118574be4b1bd30af09157 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/collating_graphormer.py @@ -0,0 +1,134 @@ +# Copyright (c) Microsoft Corporation and HuggingFace +# Licensed under the MIT License. + +from typing import Any, Dict, List, Mapping + +import numpy as np +import torch + +from ....utils import is_cython_available, requires_backends + + +if is_cython_available(): + import pyximport + + pyximport.install(setup_args={"include_dirs": np.get_include()}) + from . import algos_graphormer # noqa E402 + + +def convert_to_single_emb(x, offset: int = 512): + feature_num = x.shape[1] if len(x.shape) > 1 else 1 + feature_offset = 1 + np.arange(0, feature_num * offset, offset, dtype=np.int64) + x = x + feature_offset + return x + + +def preprocess_item(item, keep_features=True): + requires_backends(preprocess_item, ["cython"]) + + if keep_features and "edge_attr" in item.keys(): # edge_attr + edge_attr = np.asarray(item["edge_attr"], dtype=np.int64) + else: + edge_attr = np.ones((len(item["edge_index"][0]), 1), dtype=np.int64) # same embedding for all + + if keep_features and "node_feat" in item.keys(): # input_nodes + node_feature = np.asarray(item["node_feat"], dtype=np.int64) + else: + node_feature = np.ones((item["num_nodes"], 1), dtype=np.int64) # same embedding for all + + edge_index = np.asarray(item["edge_index"], dtype=np.int64) + + input_nodes = convert_to_single_emb(node_feature) + 1 + num_nodes = item["num_nodes"] + + if len(edge_attr.shape) == 1: + edge_attr = edge_attr[:, None] + attn_edge_type = np.zeros([num_nodes, num_nodes, edge_attr.shape[-1]], dtype=np.int64) + attn_edge_type[edge_index[0], edge_index[1]] = convert_to_single_emb(edge_attr) + 1 + + # node adj matrix [num_nodes, num_nodes] bool + adj = np.zeros([num_nodes, num_nodes], dtype=bool) + adj[edge_index[0], edge_index[1]] = True + + shortest_path_result, path = algos_graphormer.floyd_warshall(adj) + max_dist = np.amax(shortest_path_result) + + input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type) + attn_bias = np.zeros([num_nodes + 1, num_nodes + 1], dtype=np.single) # with graph token + + # combine + item["input_nodes"] = input_nodes + 1 # we shift all indices by one for padding + item["attn_bias"] = attn_bias + item["attn_edge_type"] = attn_edge_type + item["spatial_pos"] = shortest_path_result.astype(np.int64) + 1 # we shift all indices by one for padding + item["in_degree"] = np.sum(adj, axis=1).reshape(-1) + 1 # we shift all indices by one for padding + item["out_degree"] = item["in_degree"] # for undirected graph + item["input_edges"] = input_edges + 1 # we shift all indices by one for padding + if "labels" not in item: + item["labels"] = item["y"] + + return item + + +class GraphormerDataCollator: + def __init__(self, spatial_pos_max=20, on_the_fly_processing=False): + if not is_cython_available(): + raise ImportError("Graphormer preprocessing needs Cython (pyximport)") + + self.spatial_pos_max = spatial_pos_max + self.on_the_fly_processing = on_the_fly_processing + + def __call__(self, features: List[dict]) -> Dict[str, Any]: + if self.on_the_fly_processing: + features = [preprocess_item(i) for i in features] + + if not isinstance(features[0], Mapping): + features = [vars(f) for f in features] + batch = {} + + max_node_num = max(len(i["input_nodes"]) for i in features) + node_feat_size = len(features[0]["input_nodes"][0]) + edge_feat_size = len(features[0]["attn_edge_type"][0][0]) + max_dist = max(len(i["input_edges"][0][0]) for i in features) + edge_input_size = len(features[0]["input_edges"][0][0][0]) + batch_size = len(features) + + batch["attn_bias"] = torch.zeros(batch_size, max_node_num + 1, max_node_num + 1, dtype=torch.float) + batch["attn_edge_type"] = torch.zeros(batch_size, max_node_num, max_node_num, edge_feat_size, dtype=torch.long) + batch["spatial_pos"] = torch.zeros(batch_size, max_node_num, max_node_num, dtype=torch.long) + batch["in_degree"] = torch.zeros(batch_size, max_node_num, dtype=torch.long) + batch["input_nodes"] = torch.zeros(batch_size, max_node_num, node_feat_size, dtype=torch.long) + batch["input_edges"] = torch.zeros( + batch_size, max_node_num, max_node_num, max_dist, edge_input_size, dtype=torch.long + ) + + for ix, f in enumerate(features): + for k in ["attn_bias", "attn_edge_type", "spatial_pos", "in_degree", "input_nodes", "input_edges"]: + f[k] = torch.tensor(f[k]) + + if len(f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max]) > 0: + f["attn_bias"][1:, 1:][f["spatial_pos"] >= self.spatial_pos_max] = float("-inf") + + batch["attn_bias"][ix, : f["attn_bias"].shape[0], : f["attn_bias"].shape[1]] = f["attn_bias"] + batch["attn_edge_type"][ix, : f["attn_edge_type"].shape[0], : f["attn_edge_type"].shape[1], :] = f[ + "attn_edge_type" + ] + batch["spatial_pos"][ix, : f["spatial_pos"].shape[0], : f["spatial_pos"].shape[1]] = f["spatial_pos"] + batch["in_degree"][ix, : f["in_degree"].shape[0]] = f["in_degree"] + batch["input_nodes"][ix, : f["input_nodes"].shape[0], :] = f["input_nodes"] + batch["input_edges"][ + ix, : f["input_edges"].shape[0], : f["input_edges"].shape[1], : f["input_edges"].shape[2], : + ] = f["input_edges"] + + batch["out_degree"] = batch["in_degree"] + + sample = features[0]["labels"] + if len(sample) == 1: # one task + if isinstance(sample[0], float): # regression + batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features])) + else: # binary classification + batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features])) + else: # multi task classification, left to float to keep the NaNs + batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0)) + + return batch diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py new file mode 100644 index 0000000000000000000000000000000000000000..1ecde152e4e522ab7ca5c7d1c5d2bc5b4fdda029 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/configuration_graphormer.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2022 Microsoft, clefourrier and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Graphormer model configuration""" + +from typing import Optional + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class GraphormerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`~GraphormerModel`]. It is used to instantiate an + Graphormer model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Graphormer + [graphormer-base-pcqm4mv1](https://huggingface.co/graphormer-base-pcqm4mv1) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + num_classes (`int`, *optional*, defaults to 1): + Number of target classes or labels, set to n for binary classification of n tasks. + num_atoms (`int`, *optional*, defaults to 512*9): + Number of node types in the graphs. + num_edges (`int`, *optional*, defaults to 512*3): + Number of edges types in the graph. + num_in_degree (`int`, *optional*, defaults to 512): + Number of in degrees types in the input graphs. + num_out_degree (`int`, *optional*, defaults to 512): + Number of out degrees types in the input graphs. + num_edge_dis (`int`, *optional*, defaults to 128): + Number of edge dis in the input graphs. + multi_hop_max_dist (`int`, *optional*, defaults to 20): + Maximum distance of multi hop edges between two nodes. + spatial_pos_max (`int`, *optional*, defaults to 1024): + Maximum distance between nodes in the graph attention bias matrices, used during preprocessing and + collation. + edge_type (`str`, *optional*, defaults to multihop): + Type of edge relation chosen. + max_nodes (`int`, *optional*, defaults to 512): + Maximum number of nodes which can be parsed for the input graphs. + share_input_output_embed (`bool`, *optional*, defaults to `False`): + Shares the embedding layer between encoder and decoder - careful, True is not implemented. + num_layers (`int`, *optional*, defaults to 12): + Number of layers. + embedding_dim (`int`, *optional*, defaults to 768): + Dimension of the embedding layer in encoder. + ffn_embedding_dim (`int`, *optional*, defaults to 768): + Dimension of the "intermediate" (often named feed-forward) layer in encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads in the encoder. + self_attention (`bool`, *optional*, defaults to `True`): + Model is self attentive (False not implemented). + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the attention weights. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for the activation of the linear transformer layer. + layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + bias (`bool`, *optional*, defaults to `True`): + Uses bias in the attention module - unsupported at the moment. + embed_scale(`float`, *optional*, defaults to None): + Scaling factor for the node embeddings. + num_trans_layers_to_freeze (`int`, *optional*, defaults to 0): + Number of transformer layers to freeze. + encoder_normalize_before (`bool`, *optional*, defaults to `False`): + Normalize features before encoding the graph. + pre_layernorm (`bool`, *optional*, defaults to `False`): + Apply layernorm before self attention and the feed forward network. Without this, post layernorm will be + used. + apply_graphormer_init (`bool`, *optional*, defaults to `False`): + Apply a custom graphormer initialisation to the model before training. + freeze_embeddings (`bool`, *optional*, defaults to `False`): + Freeze the embedding layer, or train it along the model. + encoder_normalize_before (`bool`, *optional*, defaults to `False`): + Apply the layer norm before each encoder block. + q_noise (`float`, *optional*, defaults to 0.0): + Amount of quantization noise (see "Training with Quantization Noise for Extreme Model Compression"). (For + more detail, see fairseq's documentation on quant_noise). + qn_block_size (`int`, *optional*, defaults to 8): + Size of the blocks for subsequent quantization with iPQ (see q_noise). + kdim (`int`, *optional*, defaults to None): + Dimension of the key in the attention, if different from the other values. + vdim (`int`, *optional*, defaults to None): + Dimension of the value in the attention, if different from the other values. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + traceable (`bool`, *optional*, defaults to `False`): + Changes return value of the encoder's inner_state to stacked tensors. + + Example: + ```python + >>> from transformers import GraphormerForGraphClassification, GraphormerConfig + + >>> # Initializing a Graphormer graphormer-base-pcqm4mv2 style configuration + >>> configuration = GraphormerConfig() + + >>> # Initializing a model from the graphormer-base-pcqm4mv1 style configuration + >>> model = GraphormerForGraphClassification(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "graphormer" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + num_classes: int = 1, + num_atoms: int = 512 * 9, + num_edges: int = 512 * 3, + num_in_degree: int = 512, + num_out_degree: int = 512, + num_spatial: int = 512, + num_edge_dis: int = 128, + multi_hop_max_dist: int = 5, # sometimes is 20 + spatial_pos_max: int = 1024, + edge_type: str = "multi_hop", + max_nodes: int = 512, + share_input_output_embed: bool = False, + num_hidden_layers: int = 12, + embedding_dim: int = 768, + ffn_embedding_dim: int = 768, + num_attention_heads: int = 32, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + layerdrop: float = 0.0, + encoder_normalize_before: bool = False, + pre_layernorm: bool = False, + apply_graphormer_init: bool = False, + activation_fn: str = "gelu", + embed_scale: Optional[float] = None, + freeze_embeddings: bool = False, + num_trans_layers_to_freeze: int = 0, + traceable: bool = False, + q_noise: float = 0.0, + qn_block_size: int = 8, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + bias: bool = True, + self_attention: bool = True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + **kwargs, + ): + self.num_classes = num_classes + self.num_atoms = num_atoms + self.num_in_degree = num_in_degree + self.num_out_degree = num_out_degree + self.num_edges = num_edges + self.num_spatial = num_spatial + self.num_edge_dis = num_edge_dis + self.edge_type = edge_type + self.multi_hop_max_dist = multi_hop_max_dist + self.spatial_pos_max = spatial_pos_max + self.max_nodes = max_nodes + self.num_hidden_layers = num_hidden_layers + self.embedding_dim = embedding_dim + self.hidden_size = embedding_dim + self.ffn_embedding_dim = ffn_embedding_dim + self.num_attention_heads = num_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.layerdrop = layerdrop + self.encoder_normalize_before = encoder_normalize_before + self.pre_layernorm = pre_layernorm + self.apply_graphormer_init = apply_graphormer_init + self.activation_fn = activation_fn + self.embed_scale = embed_scale + self.freeze_embeddings = freeze_embeddings + self.num_trans_layers_to_freeze = num_trans_layers_to_freeze + self.share_input_output_embed = share_input_output_embed + self.traceable = traceable + self.q_noise = q_noise + self.qn_block_size = qn_block_size + + # These parameters are here for future extensions + # atm, the model only supports self attention + self.kdim = kdim + self.vdim = vdim + self.self_attention = self_attention + self.bias = bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + +__all__ = ["GraphormerConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py new file mode 100644 index 0000000000000000000000000000000000000000..1253d1365e8528da0ec4543d39d9a8f66f2255f7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/graphormer/modeling_graphormer.py @@ -0,0 +1,911 @@ +# coding=utf-8 +# Copyright 2022 Microsoft, clefourrier The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Graphormer model.""" + +import math +from typing import Iterable, Iterator, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithNoAttention, + SequenceClassifierOutput, +) +from ....modeling_utils import PreTrainedModel +from ....utils import logging +from .configuration_graphormer import GraphormerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "graphormer-base-pcqm4mv1" +_CONFIG_FOR_DOC = "GraphormerConfig" + + +def quant_noise(module: nn.Module, p: float, block_size: int): + """ + From: + https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/quant_noise.py + + Wraps modules and applies quantization noise to the weights for subsequent quantization with Iterative Product + Quantization as described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, see "And the Bit Goes Down: + Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper which consists in randomly dropping + blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + if not isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)): + raise NotImplementedError("Module unsupported for quant_noise.") + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + if module.weight.size(1) % block_size != 0: + raise AssertionError("Input features must be a multiple of block sizes") + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + if module.in_channels % block_size != 0: + raise AssertionError("Input channels must be a multiple of block sizes") + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + if k % block_size != 0: + raise AssertionError("Kernel size must be a multiple of block size") + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask.bernoulli_(p) + mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + + # scale weights and apply mask + mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class LayerDropModuleList(nn.ModuleList): + """ + From: + https://github.com/facebookresearch/fairseq/blob/dd0079bde7f678b0cd0715cbd0ae68d661b7226d/fairseq/modules/layer_drop.py + A LayerDrop implementation based on [`torch.nn.ModuleList`]. LayerDrop as described in + https://arxiv.org/abs/1909.11556. + + We refresh the choice of which layers to drop every time we iterate over the LayerDropModuleList instance. During + evaluation we always iterate over all layers. + + Usage: + + ```python + layers = LayerDropList(p=0.5, modules=[layer1, layer2, layer3]) + for layer in layers: # this might iterate over layers 1 and 3 + x = layer(x) + for layer in layers: # this might iterate over all layers + x = layer(x) + for layer in layers: # this might not iterate over any layers + x = layer(x) + ``` + + Args: + p (float): probability of dropping out each layer + modules (iterable, optional): an iterable of modules to add + """ + + def __init__(self, p: float, modules: Optional[Iterable[nn.Module]] = None): + super().__init__(modules) + self.p = p + + def __iter__(self) -> Iterator[nn.Module]: + dropout_probs = torch.empty(len(self)).uniform_() + for i, m in enumerate(super().__iter__()): + if not self.training or (dropout_probs[i] > self.p): + yield m + + +class GraphormerGraphNodeFeature(nn.Module): + """ + Compute node features for each node in the graph. + """ + + def __init__(self, config: GraphormerConfig): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_atoms = config.num_atoms + + self.atom_encoder = nn.Embedding(config.num_atoms + 1, config.hidden_size, padding_idx=config.pad_token_id) + self.in_degree_encoder = nn.Embedding( + config.num_in_degree, config.hidden_size, padding_idx=config.pad_token_id + ) + self.out_degree_encoder = nn.Embedding( + config.num_out_degree, config.hidden_size, padding_idx=config.pad_token_id + ) + + self.graph_token = nn.Embedding(1, config.hidden_size) + + def forward( + self, + input_nodes: torch.LongTensor, + in_degree: torch.LongTensor, + out_degree: torch.LongTensor, + ) -> torch.Tensor: + n_graph, n_node = input_nodes.size()[:2] + + node_feature = ( # node feature + graph token + self.atom_encoder(input_nodes).sum(dim=-2) # [n_graph, n_node, n_hidden] + + self.in_degree_encoder(in_degree) + + self.out_degree_encoder(out_degree) + ) + + graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1) + + graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1) + + return graph_node_feature + + +class GraphormerGraphAttnBias(nn.Module): + """ + Compute attention bias for each head. + """ + + def __init__(self, config: GraphormerConfig): + super().__init__() + self.num_heads = config.num_attention_heads + self.multi_hop_max_dist = config.multi_hop_max_dist + + # We do not change edge feature embedding learning, as edge embeddings are represented as a combination of the original features + # + shortest path + self.edge_encoder = nn.Embedding(config.num_edges + 1, config.num_attention_heads, padding_idx=0) + + self.edge_type = config.edge_type + if self.edge_type == "multi_hop": + self.edge_dis_encoder = nn.Embedding( + config.num_edge_dis * config.num_attention_heads * config.num_attention_heads, + 1, + ) + + self.spatial_pos_encoder = nn.Embedding(config.num_spatial, config.num_attention_heads, padding_idx=0) + + self.graph_token_virtual_distance = nn.Embedding(1, config.num_attention_heads) + + def forward( + self, + input_nodes: torch.LongTensor, + attn_bias: torch.Tensor, + spatial_pos: torch.LongTensor, + input_edges: torch.LongTensor, + attn_edge_type: torch.LongTensor, + ) -> torch.Tensor: + n_graph, n_node = input_nodes.size()[:2] + graph_attn_bias = attn_bias.clone() + graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat( + 1, self.num_heads, 1, 1 + ) # [n_graph, n_head, n_node+1, n_node+1] + + # spatial pos + # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] + spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2) + graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias + + # reset spatial pos here + t = self.graph_token_virtual_distance.weight.view(1, self.num_heads, 1) + graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t + graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t + + # edge feature + if self.edge_type == "multi_hop": + spatial_pos_ = spatial_pos.clone() + + spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1 + # set 1 to 1, input_nodes > 1 to input_nodes - 1 + spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_) + if self.multi_hop_max_dist > 0: + spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist) + input_edges = input_edges[:, :, :, : self.multi_hop_max_dist, :] + # [n_graph, n_node, n_node, max_dist, n_head] + + input_edges = self.edge_encoder(input_edges).mean(-2) + max_dist = input_edges.size(-2) + edge_input_flat = input_edges.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.num_heads) + edge_input_flat = torch.bmm( + edge_input_flat, + self.edge_dis_encoder.weight.reshape(-1, self.num_heads, self.num_heads)[:max_dist, :, :], + ) + input_edges = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.num_heads).permute( + 1, 2, 3, 0, 4 + ) + input_edges = (input_edges.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2) + else: + # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node] + input_edges = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2) + + graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + input_edges + graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset + + return graph_attn_bias + + +class GraphormerMultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__(self, config: GraphormerConfig): + super().__init__() + self.embedding_dim = config.embedding_dim + self.kdim = config.kdim if config.kdim is not None else config.embedding_dim + self.vdim = config.vdim if config.vdim is not None else config.embedding_dim + self.qkv_same_dim = self.kdim == config.embedding_dim and self.vdim == config.embedding_dim + + self.num_heads = config.num_attention_heads + self.attention_dropout_module = torch.nn.Dropout(p=config.attention_dropout, inplace=False) + + self.head_dim = config.embedding_dim // config.num_attention_heads + if not (self.head_dim * config.num_attention_heads == self.embedding_dim): + raise AssertionError("The embedding_dim must be divisible by num_heads.") + self.scaling = self.head_dim**-0.5 + + self.self_attention = True # config.self_attention + if not (self.self_attention): + raise NotImplementedError("The Graphormer model only supports self attention for now.") + if self.self_attention and not self.qkv_same_dim: + raise AssertionError("Self-attention requires query, key and value to be of the same size.") + + self.k_proj = quant_noise( + nn.Linear(self.kdim, config.embedding_dim, bias=config.bias), + config.q_noise, + config.qn_block_size, + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, config.embedding_dim, bias=config.bias), + config.q_noise, + config.qn_block_size, + ) + self.q_proj = quant_noise( + nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias), + config.q_noise, + config.qn_block_size, + ) + + self.out_proj = quant_noise( + nn.Linear(config.embedding_dim, config.embedding_dim, bias=config.bias), + config.q_noise, + config.qn_block_size, + ) + + self.onnx_trace = False + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + + def forward( + self, + query: torch.LongTensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + key_padding_mask (Bytetorch.Tensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (Bytetorch.Tensor, optional): typically used to + implement causal attention, where the mask prevents the attention from looking forward in time + (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: return the average attention weights over all + heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embedding_dim = query.size() + src_len = tgt_len + if not (embedding_dim == self.embedding_dim): + raise AssertionError( + f"The query embedding dimension {embedding_dim} is not equal to the expected embedding_dim" + f" {self.embedding_dim}." + ) + if not (list(query.size()) == [tgt_len, bsz, embedding_dim]): + raise AssertionError("Query size incorrect in Graphormer, compared to model dimensions.") + + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + if (key_bsz != bsz) or (value is None) or not (src_len, bsz == value.shape[:2]): + raise AssertionError( + "The batch shape does not match the key or value shapes provided to the attention." + ) + + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + + q *= self.scaling + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if (k is None) or not (k.size(1) == src_len): + raise AssertionError("The shape of the key generated in the attention is incorrect") + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + if key_padding_mask.size(0) != bsz or key_padding_mask.size(1) != src_len: + raise AssertionError( + "The shape of the generated padding mask for the key does not match expected dimensions." + ) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + if list(attn_weights.size()) != [bsz * self.num_heads, tgt_len, src_len]: + raise AssertionError("The attention weights generated do not match the expected dimensions.") + + if attn_bias is not None: + attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len) + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = torch.nn.functional.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.attention_dropout_module(attn_weights) + + if v is None: + raise AssertionError("No value generated") + attn = torch.bmm(attn_probs, v) + if list(attn.size()) != [bsz * self.num_heads, tgt_len, self.head_dim]: + raise AssertionError("The attention generated do not match the expected dimensions.") + + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embedding_dim) + attn: torch.Tensor = self.out_proj(attn) + + attn_weights = None + if need_weights: + attn_weights = attn_weights_float.contiguous().view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights + + def apply_sparse_mask(self, attn_weights: torch.Tensor, tgt_len: int, src_len: int, bsz: int) -> torch.Tensor: + return attn_weights + + +class GraphormerGraphEncoderLayer(nn.Module): + def __init__(self, config: GraphormerConfig) -> None: + super().__init__() + + # Initialize parameters + self.embedding_dim = config.embedding_dim + self.num_attention_heads = config.num_attention_heads + self.q_noise = config.q_noise + self.qn_block_size = config.qn_block_size + self.pre_layernorm = config.pre_layernorm + + self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False) + + self.activation_dropout_module = torch.nn.Dropout(p=config.activation_dropout, inplace=False) + + # Initialize blocks + self.activation_fn = ACT2FN[config.activation_fn] + self.self_attn = GraphormerMultiheadAttention(config) + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim) + + self.fc1 = self.build_fc( + self.embedding_dim, + config.ffn_embedding_dim, + q_noise=config.q_noise, + qn_block_size=config.qn_block_size, + ) + self.fc2 = self.build_fc( + config.ffn_embedding_dim, + self.embedding_dim, + q_noise=config.q_noise, + qn_block_size=config.qn_block_size, + ) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = nn.LayerNorm(self.embedding_dim) + + def build_fc( + self, input_dim: int, output_dim: int, q_noise: float, qn_block_size: int + ) -> Union[nn.Module, nn.Linear, nn.Embedding, nn.Conv2d]: + return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size) + + def forward( + self, + input_nodes: torch.Tensor, + self_attn_bias: Optional[torch.Tensor] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + nn.LayerNorm is applied either before or after the self-attention/ffn modules similar to the original + Transformer implementation. + """ + residual = input_nodes + if self.pre_layernorm: + input_nodes = self.self_attn_layer_norm(input_nodes) + + input_nodes, attn = self.self_attn( + query=input_nodes, + key=input_nodes, + value=input_nodes, + attn_bias=self_attn_bias, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + ) + input_nodes = self.dropout_module(input_nodes) + input_nodes = residual + input_nodes + if not self.pre_layernorm: + input_nodes = self.self_attn_layer_norm(input_nodes) + + residual = input_nodes + if self.pre_layernorm: + input_nodes = self.final_layer_norm(input_nodes) + input_nodes = self.activation_fn(self.fc1(input_nodes)) + input_nodes = self.activation_dropout_module(input_nodes) + input_nodes = self.fc2(input_nodes) + input_nodes = self.dropout_module(input_nodes) + input_nodes = residual + input_nodes + if not self.pre_layernorm: + input_nodes = self.final_layer_norm(input_nodes) + + return input_nodes, attn + + +class GraphormerGraphEncoder(nn.Module): + def __init__(self, config: GraphormerConfig): + super().__init__() + + self.dropout_module = torch.nn.Dropout(p=config.dropout, inplace=False) + self.layerdrop = config.layerdrop + self.embedding_dim = config.embedding_dim + self.apply_graphormer_init = config.apply_graphormer_init + self.traceable = config.traceable + + self.graph_node_feature = GraphormerGraphNodeFeature(config) + self.graph_attn_bias = GraphormerGraphAttnBias(config) + + self.embed_scale = config.embed_scale + + if config.q_noise > 0: + self.quant_noise = quant_noise( + nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), + config.q_noise, + config.qn_block_size, + ) + else: + self.quant_noise = None + + if config.encoder_normalize_before: + self.emb_layer_norm = nn.LayerNorm(self.embedding_dim) + else: + self.emb_layer_norm = None + + if config.pre_layernorm: + self.final_layer_norm = nn.LayerNorm(self.embedding_dim) + + if self.layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend([GraphormerGraphEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + # Apply initialization of model params after building the model + if config.freeze_embeddings: + raise NotImplementedError("Freezing embeddings is not implemented yet.") + + for layer in range(config.num_trans_layers_to_freeze): + m = self.layers[layer] + if m is not None: + for p in m.parameters(): + p.requires_grad = False + + def forward( + self, + input_nodes: torch.LongTensor, + input_edges: torch.LongTensor, + attn_bias: torch.Tensor, + in_degree: torch.LongTensor, + out_degree: torch.LongTensor, + spatial_pos: torch.LongTensor, + attn_edge_type: torch.LongTensor, + perturb=None, + last_state_only: bool = False, + token_embeddings: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[Union[torch.Tensor, List[torch.LongTensor]], torch.Tensor]: + # compute padding mask. This is needed for multi-head attention + data_x = input_nodes + n_graph, n_node = data_x.size()[:2] + padding_mask = (data_x[:, :, 0]).eq(0) + padding_mask_cls = torch.zeros(n_graph, 1, device=padding_mask.device, dtype=padding_mask.dtype) + padding_mask = torch.cat((padding_mask_cls, padding_mask), dim=1) + + attn_bias = self.graph_attn_bias(input_nodes, attn_bias, spatial_pos, input_edges, attn_edge_type) + + if token_embeddings is not None: + input_nodes = token_embeddings + else: + input_nodes = self.graph_node_feature(input_nodes, in_degree, out_degree) + + if perturb is not None: + input_nodes[:, 1:, :] += perturb + + if self.embed_scale is not None: + input_nodes = input_nodes * self.embed_scale + + if self.quant_noise is not None: + input_nodes = self.quant_noise(input_nodes) + + if self.emb_layer_norm is not None: + input_nodes = self.emb_layer_norm(input_nodes) + + input_nodes = self.dropout_module(input_nodes) + + input_nodes = input_nodes.transpose(0, 1) + + inner_states = [] + if not last_state_only: + inner_states.append(input_nodes) + + for layer in self.layers: + input_nodes, _ = layer( + input_nodes, + self_attn_padding_mask=padding_mask, + self_attn_mask=attn_mask, + self_attn_bias=attn_bias, + ) + if not last_state_only: + inner_states.append(input_nodes) + + graph_rep = input_nodes[0, :, :] + + if last_state_only: + inner_states = [input_nodes] + + if self.traceable: + return torch.stack(inner_states), graph_rep + else: + return inner_states, graph_rep + + +class GraphormerDecoderHead(nn.Module): + def __init__(self, embedding_dim: int, num_classes: int): + super().__init__() + """num_classes should be 1 for regression, or the number of classes for classification""" + self.lm_output_learned_bias = nn.Parameter(torch.zeros(1)) + self.classifier = nn.Linear(embedding_dim, num_classes, bias=False) + self.num_classes = num_classes + + def forward(self, input_nodes: torch.Tensor, **unused) -> torch.Tensor: + input_nodes = self.classifier(input_nodes) + input_nodes = input_nodes + self.lm_output_learned_bias + return input_nodes + + +class GraphormerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GraphormerConfig + base_model_prefix = "graphormer" + main_input_name_nodes = "input_nodes" + main_input_name_edges = "input_edges" + + def normal_(self, data: torch.Tensor): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, GraphormerMultiheadAttention]): + """ + Initialize the weights specific to the Graphormer Model. + """ + if isinstance(module, nn.Linear): + self.normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + self.normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, GraphormerMultiheadAttention): + self.normal_(module.q_proj.weight.data) + self.normal_(module.k_proj.weight.data) + self.normal_(module.v_proj.weight.data) + + def _init_weights( + self, + module: Union[ + nn.Linear, nn.Conv2d, nn.Embedding, nn.LayerNorm, GraphormerMultiheadAttention, GraphormerGraphEncoder + ], + ): + """ + Initialize the weights + """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + # We might be missing part of the Linear init, dependant on the layer num + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GraphormerMultiheadAttention): + module.q_proj.weight.data.normal_(mean=0.0, std=0.02) + module.k_proj.weight.data.normal_(mean=0.0, std=0.02) + module.v_proj.weight.data.normal_(mean=0.0, std=0.02) + module.reset_parameters() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, GraphormerGraphEncoder): + if module.apply_graphormer_init: + module.apply(self.init_graphormer_params) + + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class GraphormerModel(GraphormerPreTrainedModel): + """The Graphormer model is a graph-encoder model. + + It goes from a graph to its representation. If you want to use the model for a downstream classification task, use + GraphormerForGraphClassification instead. For any other downstream task, feel free to add a new class, or combine + this model with a downstream model of your choice, following the example in GraphormerForGraphClassification. + """ + + def __init__(self, config: GraphormerConfig): + super().__init__(config) + self.max_nodes = config.max_nodes + + self.graph_encoder = GraphormerGraphEncoder(config) + + self.share_input_output_embed = config.share_input_output_embed + self.lm_output_learned_bias = None + + # Remove head is set to true during fine-tuning + self.load_softmax = not getattr(config, "remove_head", False) + + self.lm_head_transform_weight = nn.Linear(config.embedding_dim, config.embedding_dim) + self.activation_fn = ACT2FN[config.activation_fn] + self.layer_norm = nn.LayerNorm(config.embedding_dim) + + self.post_init() + + def reset_output_layer_parameters(self): + self.lm_output_learned_bias = nn.Parameter(torch.zeros(1)) + + def forward( + self, + input_nodes: torch.LongTensor, + input_edges: torch.LongTensor, + attn_bias: torch.Tensor, + in_degree: torch.LongTensor, + out_degree: torch.LongTensor, + spatial_pos: torch.LongTensor, + attn_edge_type: torch.LongTensor, + perturb: Optional[torch.FloatTensor] = None, + masked_tokens: None = None, + return_dict: Optional[bool] = None, + **unused, + ) -> Union[Tuple[torch.LongTensor], BaseModelOutputWithNoAttention]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + inner_states, graph_rep = self.graph_encoder( + input_nodes, input_edges, attn_bias, in_degree, out_degree, spatial_pos, attn_edge_type, perturb=perturb + ) + + # last inner state, then revert Batch and Graph len + input_nodes = inner_states[-1].transpose(0, 1) + + # project masked tokens only + if masked_tokens is not None: + raise NotImplementedError + + input_nodes = self.layer_norm(self.activation_fn(self.lm_head_transform_weight(input_nodes))) + + # project back to size of vocabulary + if self.share_input_output_embed and hasattr(self.graph_encoder.embed_tokens, "weight"): + input_nodes = torch.nn.functional.linear(input_nodes, self.graph_encoder.embed_tokens.weight) + + if not return_dict: + return tuple(x for x in [input_nodes, inner_states] if x is not None) + return BaseModelOutputWithNoAttention(last_hidden_state=input_nodes, hidden_states=inner_states) + + def max_nodes(self): + """Maximum output length supported by the encoder.""" + return self.max_nodes + + +class GraphormerForGraphClassification(GraphormerPreTrainedModel): + """ + This model can be used for graph-level classification or regression tasks. + + It can be trained on + - regression (by setting config.num_classes to 1); there should be one float-type label per graph + - one task classification (by setting config.num_classes to the number of classes); there should be one integer + label per graph + - binary multi-task classification (by setting config.num_classes to the number of labels); there should be a list + of integer labels for each graph. + """ + + def __init__(self, config: GraphormerConfig): + super().__init__(config) + self.encoder = GraphormerModel(config) + self.embedding_dim = config.embedding_dim + self.num_classes = config.num_classes + self.classifier = GraphormerDecoderHead(self.embedding_dim, self.num_classes) + self.is_encoder_decoder = True + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_nodes: torch.LongTensor, + input_edges: torch.LongTensor, + attn_bias: torch.Tensor, + in_degree: torch.LongTensor, + out_degree: torch.LongTensor, + spatial_pos: torch.LongTensor, + attn_edge_type: torch.LongTensor, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + **unused, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_nodes, + input_edges, + attn_bias, + in_degree, + out_degree, + spatial_pos, + attn_edge_type, + return_dict=True, + ) + outputs, hidden_states = encoder_outputs["last_hidden_state"], encoder_outputs["hidden_states"] + + head_outputs = self.classifier(outputs) + logits = head_outputs[:, 0, :].contiguous() + + loss = None + if labels is not None: + mask = ~torch.isnan(labels) + + if self.num_classes == 1: # regression + loss_fct = MSELoss() + loss = loss_fct(logits[mask].squeeze(), labels[mask].squeeze().float()) + elif self.num_classes > 1 and len(labels.shape) == 1: # One task classification + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits[mask].view(-1, self.num_classes), labels[mask].view(-1)) + else: # Binary multi-task classification + loss_fct = BCEWithLogitsLoss(reduction="sum") + loss = loss_fct(logits[mask], labels[mask]) + + if not return_dict: + return tuple(x for x in [loss, logits, hidden_states] if x is not None) + return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=hidden_states, attentions=None) + + +__all__ = ["GraphormerForGraphClassification", "GraphormerModel", "GraphormerPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..826bdbddc1f182de43a795ab0d78ad4009507c14 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_jukebox import * + from .modeling_jukebox import * + from .tokenization_jukebox import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py new file mode 100644 index 0000000000000000000000000000000000000000..d10cbc2d82cfbd310cff2b84b54de4d693cea1c7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/configuration_jukebox.py @@ -0,0 +1,613 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Jukebox configuration""" + +import os +from typing import List, Union + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +_LARGE_ATTENTION = [ + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "block_attn", + "transpose_block_attn", + "prev_block_attn", + "cross_attention", +] +_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"] +_FullDenseAttention = ["dense_attention"] +_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"] + + +def full_dense_attention(layer): + return _FullDenseAttention[0] + + +def raw_column_previous_row_attention(layer): + return _RawColumnPreviousRowAttention[layer % 3] + + +def large_separated_enc_dec_w_lyrics(layer): + return _LARGE_ATTENTION[layer % 79] + + +def enc_dec_with_lyrics(layer): + if layer % 16 == 15: + return _PrimePrimeDenseAttention[layer % 3] + return _RawColumnPreviousRowAttention[layer % 3] + + +ATTENTION_PATTERNS = { + "full_dense_attention": full_dense_attention, + "raw_column_previous_row_attention": raw_column_previous_row_attention, # Alternate row, column and previous row attn + "large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics, # Used by large separated_enc_dec model with lyrics + "enc_dec_with_lyrics": enc_dec_with_lyrics, # Used by encoder_decoder model with lyrics +} + + +class JukeboxPriorConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a + `JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the top level prior from the + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox + -1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + + Args: + act_fn (`str`, *optional*, defaults to `"quick_gelu"`): + Activation function. + alignment_head (`int`, *optional*, defaults to 2): + Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio + alignment + alignment_layer (`int`, *optional*, defaults to 68): + Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the + lyric to audio alignment + attention_multiplier (`float`, *optional*, defaults to 0.25): + Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that + 0.25*width of the model will be used. + attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`): + Which attention pattern to use for the decoder/ + attn_dropout (`int`, *optional*, defaults to 0): + Dropout probability for the post-attention layer dropout in the decoder. + attn_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals in the attention conditioner block. + blocks (`int`, *optional*, defaults to 64): + Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len // + blocks]` in the `JukeboxAttention` layer. + conv_res_scale (`int`, *optional*): + Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a + conditioner, the default value is to None and should not be modified. + num_layers (`int`, *optional*, defaults to 72): + Number of layers of the transformer architecture. + emb_dropout (`int`, *optional*, defaults to 0): + Embedding dropout used in the lyric decoder. + encoder_config (`JukeboxPriorConfig`, *optional*) : + Configuration of the encoder which models the prior on the lyrics. + encoder_loss_fraction (`float`, *optional*, defaults to 0.4): + Multiplication factor used in front of the lyric encoder loss. + hidden_size (`int`, *optional*, defaults to 2048): + Hidden dimension of the attention layers. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scales for the prior modules. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is + greater than 0, the `encoder` args should be specified for the lyric encoding. + mask (`bool`, *optional*, defaults to `False`): + Whether or not to mask the previous positions in the attention. + max_duration (`int`, *optional*, defaults to 600): + Maximum supported duration of the generated song in seconds. + max_nb_genres (`int`, *optional*, defaults to 1): + Maximum number of genres that can be used to condition the model. + merged_decoder (`bool`, *optional*, defaults to `True`): + Whether or not the decoder and the encoder inputs are merged. This is used for the separated + encoder-decoder architecture + metadata_conditioning (`bool`, *optional*, defaults to `True)`: + Whether or not to condition on the artist and genre metadata. + metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`): + Number of genres and the number of artists that were used to train the embedding layers of the prior + models. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the generated audio on which the model was trained. + mlp_multiplier (`float`, *optional*, defaults to 1.0): + Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of + the model will be used. + music_vocab_size (`int`, *optional*, defaults to 2048): + Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`. + n_ctx (`int`, *optional*, defaults to 6144): + Number of context tokens for each prior. The context tokens are the music tokens that are attended to when + generating music tokens. + n_heads (`int`, *optional*, defaults to 2): + Number of attention heads. + nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384): + Number of lyric tokens that are used when sampling a single window of length `n_ctx` + res_conv_depth (`int`, *optional*, defaults to 3): + Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_conv_width (`int`, *optional*, defaults to 128): + Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the + `JukeboxMusicTokenConditioner`. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the + corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level + tokens. + res_dilation_growth_rate (`int`, *optional*, defaults to 1): + Dilation grow rate used between each convolutionnal block of the `JukeboxMusicTokenConditioner` + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rates used in the audio conditioning network + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Striding used in the audio conditioning network + resid_dropout (`int`, *optional*, defaults to 0): + Residual dropout used in the attention pattern. + sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate used for training. + spread (`int`, *optional*): + Spread used in the `summary_spread_attention` pattern + timing_dims (`int`, *optional*, defaults to 64): + Dimension of the timing embedding. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. + """ + + model_type = "jukebox_prior" + attribute_map = { + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + } + + def __init__( + self, + act_fn="quick_gelu", + level=0, + alignment_head=2, + alignment_layer=68, + attention_multiplier=0.25, + attention_pattern="enc_dec_with_lyrics", + attn_dropout=0, + attn_res_scale=False, + blocks=64, + conv_res_scale=None, + num_layers=72, + emb_dropout=0, + encoder_config=None, + encoder_loss_fraction=0.4, + hidden_size=2048, + init_scale=0.2, + is_encoder_decoder=True, + lyric_vocab_size=80, + mask=False, + max_duration=600, + max_nb_genres=1, + merged_decoder=True, + metadata_conditioning=True, + metadata_dims=[604, 7898], + min_duration=0, + mlp_multiplier=1.0, + music_vocab_size=2048, + n_ctx=6144, + n_heads=2, + nb_relevant_lyric_tokens=384, + res_conv_depth=3, + res_conv_width=128, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=1, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + resid_dropout=0, + sampling_rate=44100, + spread=None, + timing_dims=64, + zero_out=False, + **kwargs, + ): + self.act_fn = act_fn + self.alignment_head = alignment_head + self.alignment_layer = alignment_layer + self.attention_multiplier = attention_multiplier + self.attention_pattern = attention_pattern + self.attn_dropout = attn_dropout + self.attn_res_scale = attn_res_scale + self.blocks = blocks + self.conv_res_scale = conv_res_scale + self.num_layers = num_layers + self.emb_dropout = emb_dropout + self.music_vocab_size = music_vocab_size + if encoder_config is not None: + self.encoder_config = JukeboxPriorConfig(**encoder_config) + else: + self.encoder_config = None + self.encoder_loss_fraction = encoder_loss_fraction + self.init_scale = init_scale + self.is_encoder_decoder = is_encoder_decoder + self.lyric_vocab_size = lyric_vocab_size + self.level = level + self.mask = mask + self.max_duration = max_duration + self.max_nb_genres = max_nb_genres + self.merged_decoder = merged_decoder + self.metadata_conditioning = metadata_conditioning + self.metadata_dims = metadata_dims + self.min_duration = min_duration + self.mlp_multiplier = mlp_multiplier + self.n_ctx = n_ctx + self.n_heads = n_heads + self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens + self.res_conv_depth = res_conv_depth + self.res_conv_width = res_conv_width + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_cycle = res_dilation_cycle + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.resid_dropout = resid_dropout + self.sampling_rate = sampling_rate + self.spread = spread + self.timing_dims = timing_dims + self.hidden_size = hidden_size + self.zero_out = zero_out + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the prior config dict if we are loading from JukeboxConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict[f"prior_{level}"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class JukeboxVQVAEConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a + `JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VQVAE from + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function of the model. + nb_discrete_codes (`int`, *optional*, defaults to 2048): + Number of codes of the VQVAE. + commit (`float`, *optional*, defaults to 0.02): + Commit loss multiplier. + conv_input_shape (`int`, *optional*, defaults to 1): + Number of audio channels. + conv_res_scale (`bool`, *optional*, defaults to `False`): + Whether or not to scale the residuals of the `JukeboxResConv1DBlock`. + embed_dim (`int`, *optional*, defaults to 64): + Embedding dimension of the codebook vectors. + hop_fraction (`List[int]`, *optional*, defaults to `[0.125, 0.5, 0.5]`): + Fraction of non-intersecting window used when continuing the sampling process. + levels (`int`, *optional*, defaults to 3): + Number of hierarchical levels that used in the VQVAE. + lmu (`float`, *optional*, defaults to 0.99): + Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1 + of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf) + multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`): + Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth` + res_conv_depth (`int`, *optional*, defaults to 4): + Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_conv_width (`int`, *optional*, defaults to 32): + Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level. + res_convolution_multiplier (`int`, *optional*, defaults to 1): + Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`. + res_dilation_cycle (`int`, *optional*): + Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth + reduced by a power of `res_dilation_cycle`. + res_dilation_growth_rate (`int`, *optional*, defaults to 3): + Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth) + res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`): + Downsampling rate for each level of the hierarchical VQ-VAE. + res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`): + Stride used for each level of the hierarchical VQ-VAE. + sample_length (`int`, *optional*, defaults to 1058304): + Provides the max input shape of the VQVAE. Is used to compute the input shape of each level. + init_scale (`float`, *optional*, defaults to 0.2): + Initialization scale. + zero_out (`bool`, *optional*, defaults to `False`): + Whether or not to zero out convolution weights when initializing. + """ + + model_type = "jukebox_vqvae" + + def __init__( + self, + act_fn="relu", + nb_discrete_codes=2048, + commit=0.02, + conv_input_shape=1, + conv_res_scale=False, + embed_dim=64, + hop_fraction=[0.125, 0.5, 0.5], + levels=3, + lmu=0.99, + multipliers=[2, 1, 1], + res_conv_depth=4, + res_conv_width=32, + res_convolution_multiplier=1, + res_dilation_cycle=None, + res_dilation_growth_rate=3, + res_downs_t=[3, 2, 2], + res_strides_t=[2, 2, 2], + sample_length=1058304, + init_scale=0.2, + zero_out=False, + **kwargs, + ): + self.hop_fraction = hop_fraction + self.conv_input_shape = conv_input_shape + self.sample_length = sample_length + + # VQVAE parameters (all used) + self.levels = levels + self.embed_dim = embed_dim + self.nb_discrete_codes = nb_discrete_codes + self.res_conv_width = res_conv_width + self.res_conv_depth = res_conv_depth + self.res_convolution_multiplier = res_convolution_multiplier + self.res_dilation_growth_rate = res_dilation_growth_rate + self.res_dilation_cycle = res_dilation_cycle + self.multipliers = multipliers + self.res_downs_t = res_downs_t + self.res_strides_t = res_strides_t + self.lmu = lmu + self.commit = commit + self.conv_res_scale = conv_res_scale + self.act_fn = act_fn + self.init_scale = init_scale + self.zero_out = zero_out + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from CLIPConfig + if config_dict.get("model_type") == "jukebox": + config_dict = config_dict["vqvae_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class JukeboxConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`JukeboxModel`]. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will + yield a similar configuration to that of + [openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture. + + + The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling = + (5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256 + to get the second level codes. This is mostly true for training the top level prior and the upsamplers. + + Args: + vqvae_config (`JukeboxVQVAEConfig`, *optional*): + Configuration for the `JukeboxVQVAE` model. + prior_config_list (`List[JukeboxPriorConfig]`, *optional*): + List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors. + nb_priors (`int`, *optional*, defaults to 3): + Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive + (decoder) model, apart from the top prior, which can include a lyric encoder. The available models were + trained using a top prior and 2 upsampler priors. + sampling_rate (`int`, *optional*, defaults to 44100): + Sampling rate of the raw audio. + timing_dims (`int`, *optional*, defaults to 64): + Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding + layer. The timing embedding layer converts the absolute and relative position in the currently sampled + audio to a tensor of length `timing_dims` that will be added to the music tokens. + min_duration (`int`, *optional*, defaults to 0): + Minimum duration of the audios to generate + max_duration (`float`, *optional*, defaults to 600.0): + Maximum duration of the audios to generate + max_nb_genres (`int`, *optional*, defaults to 5): + Maximum number of genres that can be used to condition a single sample. + metadata_conditioning (`bool`, *optional*, defaults to `True`): + Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum + duration. + + Example: + + ```python + >>> from transformers import JukeboxModel, JukeboxConfig + + >>> # Initializing a Jukebox configuration + >>> configuration = JukeboxConfig() + + >>> # Initializing a model from the configuration + >>> model = JukeboxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "jukebox" + + def __init__( + self, + vqvae_config=None, + prior_config_list=None, + nb_priors=3, + sampling_rate=44100, + timing_dims=64, + min_duration=0, + max_duration=600.0, + max_nb_genres=5, + metadata_conditioning=True, + **kwargs, + ): + if vqvae_config is None: + vqvae_config = {} + logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.") + + self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config) + if prior_config_list is not None: + self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list] + else: + self.prior_configs = [] + for prior_idx in range(nb_priors): + prior_config = kwargs.pop(f"prior_{prior_idx}", None) + if prior_config is None: + prior_config = {} + logger.info( + f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default" + " values." + ) + self.prior_configs.append(JukeboxPriorConfig(**prior_config)) + + self.hop_fraction = self.vqvae_config.hop_fraction + + self.nb_priors = nb_priors + + # Metadata conditioning + self.max_nb_genres = max_nb_genres + self.sampling_rate = sampling_rate + self.timing_dims = timing_dims + self.min_duration = min_duration + self.max_duration = max_duration + self.metadata_conditioning = metadata_conditioning + + super().__init__(**kwargs) + + @classmethod + def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs): + r""" + Instantiate a [`JukeboxConfig`] (or a derived class) from clip text model configuration and clip vision model + configuration. + + Returns: + [`JukeboxConfig`]: An instance of a configuration object + """ + prior_config_list = [config.to_dict() for config in prior_configs] + return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs) + + def to_dict(self): + # Override the default to_dict to apply to_dict to the list of prior configs. + result = super().to_dict() + result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")] + return result + + +__all__ = ["JukeboxConfig", "JukeboxPriorConfig", "JukeboxVQVAEConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py new file mode 100644 index 0000000000000000000000000000000000000000..aac3b2efe733bd9f0c4eefb2d5442e15427c9347 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/convert_jukebox.py @@ -0,0 +1,279 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Jukebox checkpoints""" + +import argparse +import json +import os +from pathlib import Path + +import requests +import torch + +from transformers import JukeboxConfig, JukeboxModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +PREFIX = "https://openaipublic.azureedge.net/jukebox/models/" +MODEL_MAPPING = { + "jukebox-1b-lyrics": [ + "5b/vqvae.pth.tar", + "5b/prior_level_0.pth.tar", + "5b/prior_level_1.pth.tar", + "1b_lyrics/prior_level_2.pth.tar", + ], + "jukebox-5b-lyrics": [ + "5b/vqvae.pth.tar", + "5b/prior_level_0.pth.tar", + "5b/prior_level_1.pth.tar", + "5b_lyrics/prior_level_2.pth.tar", + ], +} + + +def replace_key(key): + if key.endswith(".model.1.bias") and len(key.split(".")) > 10: + key = key.replace(".model.1.bias", ".conv1d_1.bias") + elif key.endswith(".model.1.weight") and len(key.split(".")) > 10: + key = key.replace(".model.1.weight", ".conv1d_1.weight") + elif key.endswith(".model.3.bias") and len(key.split(".")) > 10: + key = key.replace(".model.3.bias", ".conv1d_2.bias") + elif key.endswith(".model.3.weight") and len(key.split(".")) > 10: + key = key.replace(".model.3.weight", ".conv1d_2.weight") + + if "conditioner_blocks.0." in key: + key = key.replace("conditioner_blocks.0", "conditioner_blocks") + + if "prime_prior" in key: + key = key.replace("prime_prior", "encoder") + + if ".emb." in key and "total" not in key and "absolute" not in key and "relative" not in key: + key = key.replace(".emb.", ".") + + if key.endswith("k"): # replace vqvae.X.k with vqvae.X.codebook + return key.replace(".k", ".codebook") + if "y_emb." in key: + return key.replace("y_emb.", "metadata_embedding.") + + if "x_emb.emb." in key: + key = key.replace("0.x_emb.emb", "embed_tokens") + + if "prime_state_ln" in key: + return key.replace("prime_state_ln", "encoder.final_layer_norm") + if ".ln" in key: + return key.replace(".ln", ".layer_norm") + if "_ln" in key: + return key.replace("_ln", "_layer_norm") + + if "prime_state_proj" in key: + return key.replace("prime_state_proj", "encoder.proj_in") + if "prime_x_out" in key: + return key.replace("prime_x_out", "encoder.lm_head") + if "prior.x_out" in key: + return key.replace("x_out", "fc_proj_out") + if "x_emb" in key: + return key.replace("x_emb", "embed_tokens") + + return key + + +def fix_jukebox_keys(state_dict, model_state_dict, key_prefix, mapping): + new_dict = {} + import re + + re_encoder_block_conv_in = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)") + re_encoder_block_resnet = re.compile( + r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_encoder_block_proj_out = re.compile(r"encoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)") + + re_decoder_block_conv_out = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).(bias|weight)") + re_decoder_block_resnet = re.compile( + r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_decoder_block_proj_in = re.compile(r"decoders.(\d*).level_blocks.(\d*).model.(\d*).(bias|weight)") + + re_prior_cond_conv_out = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).(bias|weight)") + re_prior_cond_resnet = re.compile( + r"conditioner_blocks.(\d*).cond.model.(\d*).(\d).model.(\d*).model.(\d*).(bias|weight)" + ) + re_prior_cond_proj_in = re.compile(r"conditioner_blocks.(\d*).cond.model.(\d*).(bias|weight)") + + for original_key, value in state_dict.items(): + # rename vqvae.encoder keys + if re_encoder_block_conv_in.fullmatch(original_key): + regex_match = re_encoder_block_conv_in.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) + re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}.{groups[-1]}" + key = re_encoder_block_conv_in.sub(re_new_key, original_key) + + elif re_encoder_block_resnet.fullmatch(original_key): + regex_match = re_encoder_block_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"encoders.{groups[0]}.level_blocks.{groups[1]}.downsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_encoder_block_resnet.sub(re_new_key, original_key) + + elif re_encoder_block_proj_out.fullmatch(original_key): + regex_match = re_encoder_block_proj_out.match(original_key) + groups = regex_match.groups() + re_new_key = f"encoders.{groups[0]}.level_blocks.{groups[1]}.proj_out.{groups[-1]}" + key = re_encoder_block_proj_out.sub(re_new_key, original_key) + + # rename vqvae.decoder keys + elif re_decoder_block_conv_out.fullmatch(original_key): + regex_match = re_decoder_block_conv_out.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) - 2 + re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}.{groups[-1]}" + key = re_decoder_block_conv_out.sub(re_new_key, original_key) + + elif re_decoder_block_resnet.fullmatch(original_key): + regex_match = re_decoder_block_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[2]) * 2 + int(groups[3]) - 2 + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"decoders.{groups[0]}.level_blocks.{groups[1]}.upsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_decoder_block_resnet.sub(re_new_key, original_key) + + elif re_decoder_block_proj_in.fullmatch(original_key): + regex_match = re_decoder_block_proj_in.match(original_key) + groups = regex_match.groups() + re_new_key = f"decoders.{groups[0]}.level_blocks.{groups[1]}.proj_in.{groups[-1]}" + key = re_decoder_block_proj_in.sub(re_new_key, original_key) + + # rename prior cond.model to upsampler.upsample_block and resnet + elif re_prior_cond_conv_out.fullmatch(original_key): + regex_match = re_prior_cond_conv_out.match(original_key) + groups = regex_match.groups() + block_index = int(groups[1]) * 2 + int(groups[2]) - 2 + re_new_key = f"conditioner_blocks.upsampler.upsample_block.{block_index}.{groups[-1]}" + key = re_prior_cond_conv_out.sub(re_new_key, original_key) + + elif re_prior_cond_resnet.fullmatch(original_key): + regex_match = re_prior_cond_resnet.match(original_key) + groups = regex_match.groups() + block_index = int(groups[1]) * 2 + int(groups[2]) - 2 + conv_index = {"1": 1, "3": 2}[groups[-2]] + prefix = f"conditioner_blocks.upsampler.upsample_block.{block_index}." + resnet_block = f"resnet_block.{groups[-3]}.conv1d_{conv_index}.{groups[-1]}" + re_new_key = prefix + resnet_block + key = re_prior_cond_resnet.sub(re_new_key, original_key) + + elif re_prior_cond_proj_in.fullmatch(original_key): + regex_match = re_prior_cond_proj_in.match(original_key) + groups = regex_match.groups() + re_new_key = f"conditioner_blocks.upsampler.proj_in.{groups[-1]}" + key = re_prior_cond_proj_in.sub(re_new_key, original_key) + + # keep original key + else: + key = original_key + + key = replace_key(key) + + if f"{key_prefix}.{key}" not in model_state_dict or key is None: + print(f"failed converting {original_key} to {key}, does not match") + + # handle missmatched shape + elif value.shape != model_state_dict[f"{key_prefix}.{key}"].shape: + val = model_state_dict[f"{key_prefix}.{key}"] + print(f"{original_key}-> {key} : \nshape {val.shape} and {value.shape}, do not match") + key = original_key + + mapping[key] = original_key + new_dict[key] = value + + return new_dict + + +@torch.no_grad() +def convert_openai_checkpoint(model_name=None, pytorch_dump_folder_path=None): + """ + Copy/paste/tweak model's weights to our Jukebox structure. + """ + for file in MODEL_MAPPING[model_name]: + if not os.path.isfile(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}"): + r = requests.get(f"{PREFIX}{file}", allow_redirects=True) + os.makedirs(f"{pytorch_dump_folder_path}/", exist_ok=True) + open(f"{pytorch_dump_folder_path}/{file.split('/')[-1]}", "wb").write(r.content) + + model_to_convert = MODEL_MAPPING[model_name.split("/")[-1]] + + config = JukeboxConfig.from_pretrained(model_name) + model = JukeboxModel(config) + + weight_dict = [] + mapping = {} + for i, dict_name in enumerate(model_to_convert): + old_dic = torch.load(f"{pytorch_dump_folder_path}/{dict_name.split('/')[-1]}", weights_only=True)["model"] + + new_dic = {} + for k in old_dic.keys(): + if k.endswith(".b"): + new_dic[k.replace("b", "bias")] = old_dic[k] + elif k.endswith(".w"): + new_dic[k.replace("w", "weight")] = old_dic[k] + elif "level_2" not in dict_name and "cond.model." in k: + new_dic[k.replace(".blocks.", ".model.")] = old_dic[k] + else: + new_dic[k] = old_dic[k] + + key_prefix = "vqvae" if i == 0 else f"priors.{3 - i}" + new_dic = fix_jukebox_keys(new_dic, model.state_dict(), key_prefix, mapping) + weight_dict.append(new_dic) + + vqvae_state_dict = weight_dict.pop(0) + model.vqvae.load_state_dict(vqvae_state_dict) + for i in range(len(weight_dict)): + model.priors[i].load_state_dict(weight_dict[2 - i]) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + with open(f"{pytorch_dump_folder_path}/mapping.json", "w") as txtfile: + json.dump(mapping, txtfile) + + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + return weight_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="jukebox-5b-lyrics", + type=str, + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="jukebox-5b-lyrics-converted", + type=str, + help="Path to the output PyTorch model directory.", + ) + args = parser.parse_args() + convert_openai_checkpoint(args.model_name, args.pytorch_dump_folder_path) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py new file mode 100644 index 0000000000000000000000000000000000000000..566148ceda36b3581c653c2c29ec7265d2582135 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/modeling_jukebox.py @@ -0,0 +1,2670 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Jukebox model.""" + +import math +import os +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import LayerNorm as FusedLayerNorm + +from ....activations import ACT2FN +from ....modeling_utils import PreTrainedModel +from ....utils import add_start_docstrings, logging +from ....utils.logging import tqdm +from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig + + +logger = logging.get_logger(__name__) + + +def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + + Args: + logits (`torch.Tensor`): + logits distribution shape (vocabulary size) + top_k (`int`, *optional*, defaults to 0): + When `top_k >0` keep only top key tokens with highest probability (top-k filtering). + top_p (`int`, *optional*, defaults to 0): + When `top_p>0.0` keep the top tokens with cumulative probability >= `top_p` (nucleus filtering). + """ + logits = logits.clone() + top_k = min(top_k, logits.size(-1)) # Safety check + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # indices_to_remove = sorted_indices[sorted_indices_to_remove] + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_( + dim=-1, index=sorted_indices, src=sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration): + """ + Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be + returned. If the provided token sequence is smaller, it will be padded, otherwise, only characters ranging from the + midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses* on + the most relevant tokens (in time) for the sequence. + + Args: + full_tokens (`List[int]`): + List containing the token ids of the entire lyrics. + total_length (`int`): + Total expected length of the music (not all of it is generated, see duration), in samples. + offset (`int`): + Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted take that into + account + duration (`int`): + Expected duration of the generated music, in samples. The duration has to be smaller than the total length, + which represent the overall length of the signal, + """ + full_tokens = full_tokens[0] + if len(full_tokens) < max_n_lyric_tokens: + tokens = torch.cat( + [torch.zeros(max_n_lyric_tokens - len(full_tokens), dtype=torch.long).to(full_tokens.device), full_tokens] + ) + indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens))) + else: + midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length) + midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2) + tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2] + indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2)) + return tokens.unsqueeze(dim=0), indices + + +# Break total_length into hops/windows of size n_ctx separated by hop_length +def get_starts(total_length, n_ctx, hop_length): + starts = [] + for start in range(0, total_length - n_ctx + hop_length, hop_length): + if start + n_ctx >= total_length: + # Last hop could be smaller, we make it n_ctx to maximise context + start = total_length - n_ctx + starts.append(start) + return starts + + +def get_alignment(music_tokens, labels, prior, config): + level = prior.levels - 1 # Top level used + n_ctx = prior.n_ctx + tokens = music_tokens[level] + batch_size, total_length = tokens.shape[0], tokens.shape[1] + if total_length < n_ctx: + padding_length = n_ctx - total_length + tokens = torch.cat( + [tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1 + ) + total_length = tokens.shape[1] + else: + padding_length = 0 + + hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx) + alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0] + attn_layers = {alignment_layer} + alignment_hops = {} + indices_hops = {} + for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "): + end = start + n_ctx + # set metadata offset, sample_length and lyrics tokens + metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0) + tokens_bs = torch.chunk(tokens, batch_size, dim=0) + metadata_bs = torch.chunk(metadata, batch_size, dim=0) + w_hops = [] + for tokens_i, metadata_i in zip(tokens_bs, metadata_bs): + w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers) + w_hops.append(w_hop[0][:, alignment_head]) + del w_hop + weights = torch.cat(w_hops, dim=0) + del w_hops + alignment_hop = weights.to(device="cpu", dtype=torch.float).numpy() + del weights + + # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens) + # indices_hop is a list of len=bs, each entry of len hps.nb_relevant_lyric_tokens + indices_hops[start] = indices_hop + alignment_hops[start] = alignment_hop + + # Combine attn for each hop into attn for full range + # Use indices to place them into correct place for corresponding source tokens + alignments = [] + for item in range(batch_size): + # Note each item has different length lyrics + full_tokens = labels[0, 3:] + alignment = np.zeros((total_length, len(full_tokens) + 1)) + for start in reversed(get_starts(total_length, n_ctx, hop_length)): + end = start + n_ctx + alignment_hop = alignment_hops[start][item] + indices = indices_hops[start][item] + alignment[start:end, indices] = alignment_hop + alignment = alignment[: total_length - padding_length, :-1] # remove token padding, and last lyric index + alignments.append(alignment) + return alignments + + +def save_temp_audio(fname, lvl, metas, aud): + aud = torch.clamp(aud, -1, 1).cpu().numpy() + for i in list(range(aud.shape[0])): + if metas is not None: + artists, genres, lyrics = list(metas)[i].values() + path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}" + np.save(path, aud[i]) + else: + np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i]) + + +def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t): + # returns a mask of shape 1 x 1 x query_length x key_value_length or None if masking is not needed. + if mask is None or query_length == 1: + return None + offset = sample_t - query_length if sample else max(key_value_length - query_length, 0) + if mask == "autoregressive": + # Masked dense + mask = torch.ones(query_length, key_value_length, device=device).tril(offset) + elif mask == "summary": + # Masked summary + mask = torch.ones(query_length, query_length, device=device).tril() + mask = torch.ones(query_length, query_length, device=device).tril() + mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :] + mask = ( + torch.nn.functional.pad( + mask, + (0, 0, 1, 0), + value=1, + ) + .contiguous() + .view(query_length, key_value_length) + ) + elif mask == "prime": + mask = torch.ones(query_length, key_value_length, device=device).tril(offset) + return mask.view(1, 1, query_length, key_value_length) + + +class JukeboxConv1D(nn.Module): + def __init__(self, input_width, output_width): + super().__init__() + self.input_width = input_width + self.output_width = output_width + weight = torch.empty(input_width, output_width) + bias = torch.zeros(output_width) + self.weight = nn.Parameter(weight) + self.bias = nn.Parameter(bias) + + def forward(self, hidden_states): + size_out = (*hidden_states.size()[:-1], self.output_width) + hidden_states = torch.addmm( + self.bias.type_as(hidden_states), + hidden_states.view(-1, hidden_states.size(-1)), + self.weight.type_as(hidden_states), + ) + hidden_states = hidden_states.view(*size_out) + return hidden_states + + +class JukeboxResConv1DBlock(nn.Module): + def __init__(self, config, conv_width, depth=1, res_scale=1.0): + super().__init__() + hidden_dim = config.res_convolution_multiplier * conv_width + dilation = config.res_dilation_growth_rate**depth + padding = dilation + + self.res_scale = res_scale + self.activation = nn.ReLU() + self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation) + self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0) + + def forward(self, hidden_states): + residuals = hidden_states + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv1d_2(hidden_states) + return residuals + self.res_scale * hidden_states + + +class JukeboxResnet1D(nn.Module): + def __init__(self, config, conv_width, n_depth, reverse_dilation=False): + super().__init__() + self.dilation_cycle = config.res_dilation_cycle + res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth) + + blocks = [] + for depth in range(n_depth): + block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle + blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale)) + + if reverse_dilation: + blocks = blocks[::-1] + self.resnet_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.resnet_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxEncoderConvBlock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t): + super().__init__() + blocks = [] + filter_t = stride_t * 2 + pad_t = stride_t // 2 + if down_t > 0: + for i in range(down_t): + blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t)) + blocks.append(JukeboxResnet1D(config, hidden_dim, depth)) + self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1) + self.downsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + for block in self.downsample_block: + hidden_states = block(hidden_states) + hidden_states = self.proj_out(hidden_states) + return hidden_states + + +class JukeboxEncoder(nn.Module): + def __init__(self, config, width, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + + iterator = zip(list(range(self.levels)), downs_t, strides_t) + for i, down_t, stride_t in iterator: + self.level_blocks.append( + JukeboxEncoderConvBlock( + config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t + ) + ) + + def forward(self, hidden_states): + all_hidden_states = [] + + # 64, 32, ... + for level in range(self.levels): + level_block = self.level_blocks[level] + hidden_states = level_block(hidden_states) + all_hidden_states.append(hidden_states) + + return all_hidden_states + + +class JukeboxDecoderConvBock(nn.Module): + def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True): + self.embed_dim = embed_dim + self.hidden_dim = hidden_dim + super().__init__() + blocks = [] + if down_t > 0: + filter_t = stride_t * 2 + pad_t = stride_t // 2 + self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1) + for i in range(down_t): + blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation)) + blocks.append( + nn.ConvTranspose1d( + hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t + ) + ) + self.upsample_block = nn.ModuleList(blocks) + + def forward(self, hidden_states): + hidden_states = self.proj_in(hidden_states) + for block in self.upsample_block: + hidden_states = block(hidden_states) + return hidden_states + + +class JukeboxDecoder(nn.Module): + def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t): + self.level_blocks.append( + JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t) + ) + + self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1) + + def forward(self, hidden_states, all_levels=True): + hidden_state = hidden_states[-1] + + # 32, 64 ... + for level in reversed(range(self.levels)): + level_block = self.level_blocks[level] + hidden_state = level_block(hidden_state) + + if level != 0 and all_levels: + hidden_state = hidden_state + hidden_states[level - 1] + + hidden_state = self.out(hidden_state) + return hidden_state + + +class JukeboxBottleneckBlock(nn.Module): + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__() + self.nb_discrete_codes = config.nb_discrete_codes + self.codebook_width = config.embed_dim + self.mu = config.lmu + self.threshold = 1.0 + self.init = False + self.codebook_sum = None + self.codebook_elem = None + self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width)) + + def _tile(self, hidden_states): + dim, embed_width = hidden_states.shape + if dim < self.nb_discrete_codes: + n_repeats = (self.nb_discrete_codes + dim - 1) // dim + std = 0.01 / np.sqrt(embed_width) + hidden_states = hidden_states.repeat(n_repeats, 1) + hidden_states = hidden_states + torch.randn_like(hidden_states) * std + return hidden_states + + def init_codebook(self, hidden_states): + nb_discrete_codes = self.nb_discrete_codes + self.init = True + codes = self._tile(hidden_states) + self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes] + self.codebook_sum = self.codebook + self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device) + + def update_codebook(self, hidden_states, latent_states): + mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes + with torch.no_grad(): + # Calculate new centres + # nb_discrete_codes, batch_size * seq_length + latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device) + latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1) + + _codebook_sum = torch.matmul(latent_states_onehot, hidden_states) + _codebook_elem = latent_states_onehot.sum(dim=-1) # nb_discrete_codes + codes = self._tile(hidden_states) + _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes] + + # Update centres + old_codebook = self.codebook + self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum + self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem # nb_discrete_codes + usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float() + + norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view( + nb_discrete_codes, 1 + ) + self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook + _codebook_prob = _codebook_elem / torch.sum(_codebook_elem) # prob of each bin + entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8)) # entropy ie how diverse + used_curr = (_codebook_elem >= self.threshold).sum() + usage = torch.sum(usage) + dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) + return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk} + + def preprocess(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1).contiguous() + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if hidden_states.shape[-1] == self.codebook_width: + prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt( + np.prod(hidden_states.shape) + ) + elif hidden_states.shape[-1] == 2 * self.codebook_width: + x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :] + prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) + ) + + # Normalise + hidden_states = x1 + x2 + + return hidden_states, prenorm + + def postprocess(self, latent_states, dequantised_states, x_shape): + batch_size, time = x_shape + dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous() + latent_states = latent_states.view(batch_size, time) + return latent_states, dequantised_states + + def quantise(self, latent_states): + # Calculate latent code latent_states + codebook_weights = self.codebook.t() + distance = ( + torch.sum(latent_states**2, dim=-1, keepdim=True) + - 2 * torch.matmul(latent_states, codebook_weights) + + torch.sum(codebook_weights**2, dim=0, keepdim=True) + ) # (batch_size * latent_states , codebook_weights) + min_distance, music_tokens = torch.min(distance, dim=-1) + fit = torch.mean(min_distance) + return music_tokens, fit + + def dequantise(self, music_tokens): + dequantised_states = F.embedding(music_tokens, self.codebook) + return dequantised_states + + def encode(self, latent_states): + samples, _, seq_len = latent_states.shape + + # Preprocess. + latent_states, _ = self.preprocess(latent_states) + + # Quantise + music_tokens, _ = self.quantise(latent_states) + + # Postprocess. + music_tokens = music_tokens.view(samples, seq_len) + return music_tokens + + def decode(self, music_tokens): + samples, seq_len = music_tokens.shape + + # Dequantise + dequantised_states = self.dequantise(music_tokens) + + # Postprocess + dequantised_states = ( + dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous() + ) + return dequantised_states + + def forward(self, hidden_states, update_codebook=True): + samples, _, seq_len = hidden_states.shape + + # Preprocess + hidden_states, prenorm = self.preprocess(hidden_states) + + # Init codebook if not inited + if update_codebook and not self.init: + self.init_codebook(hidden_states) + + # Quantise and dequantise through bottleneck + music_tokens, fit = self.quantise(hidden_states) + dequantised_states = self.dequantise(music_tokens) + + # Update embeddings + if update_codebook: + update_metrics = self.update_codebook(hidden_states, music_tokens) + else: + update_metrics = {} + + # Loss + commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod( + hidden_states.shape + ) + + # Passthrough + dequantised_states = hidden_states + (dequantised_states - hidden_states).detach() + + # Postprocess + music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len)) + return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics) + + +class JukeboxBottleneck(nn.Module): + def __init__(self, config, levels): + super().__init__() + self.levels = levels + self.level_blocks = nn.ModuleList() + for level in range(self.levels): + self.level_blocks.append(JukeboxBottleneckBlock(config)) + + def encode(self, raw_audio): + music_tokens = [ + level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio) + ] + return music_tokens + + def decode(self, music_tokens, start_level=0, end_level=None): + if end_level is None: + end_level = self.levels + quantised_audio = [ + level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens) + ] + return quantised_audio + + def forward(self, input_audio): + music_tokens, quantised_states, commit_losses, metrics = [], [], [], [] + for level in range(self.levels): + level_block = self.level_blocks[-level - 1] + hidden_states = input_audio[level] + sampled_tokens, quantised_state, commit_loss, metric = level_block( + hidden_states, update_codebook=self.training + ) + music_tokens.append(sampled_tokens) + if not self.training: + # Be extra paranoid and make sure the encoder weights can't + # change from straight-through estimator + quantised_state = quantised_state.detach() + quantised_states.append(quantised_state) + commit_losses.append(commit_loss) + if self.training: + metrics.append(metric) + return music_tokens, quantised_states, commit_losses, metrics + + +JUKEBOX_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config (`JukeboxConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams, Sam +Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111). + + """, + JUKEBOX_START_DOCSTRING, +) +class JukeboxVQVAE(PreTrainedModel): + config_class = JukeboxVQVAEConfig + base_model_prefix = "vqvae" + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): # embed_tokens + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weight.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxVQVAEConfig): + super().__init__(config) + downs_t = config.res_downs_t + strides_t = config.res_strides_t + if not config.sample_length: + downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + top_raw_to_tokens = np.prod(downsamples) + config.sample_length = ( + config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens + ) * top_raw_to_tokens + config.sample_length = config.sample_length.astype(int) + + self.nb_discrete_codes = config.nb_discrete_codes + self.commit = config.commit + self.sample_length = config.sample_length + + self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)] + self.hop_lengths = np.cumprod(self.downsamples) + self.levels = levels = config.levels + self.music_tokens_shapes = [ + (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels) + ] + + self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + for level in range(levels): + width = config.res_conv_width * self.multipliers[level] + depth = config.res_conv_depth * self.multipliers[level] + self.encoders.append( + JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + self.decoders.append( + JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1]) + ) + + self.bottleneck = JukeboxBottleneck(config, levels) + + def _decode(self, music_tokens, start_level=0, end_level=None): + # Decode + if end_level is None: + end_level = self.levels + latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level) + # Use only lowest level + decoder, dequantised_state = self.decoders[start_level], latent_states[0:1] + dequantised_state = decoder(dequantised_state, all_levels=False) + dequantised_state = dequantised_state.permute(0, 2, 1) + return dequantised_state + + def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor: + """ + Transforms the input `music_tokens` to their `raw_audio` representation. + + Args: + music_tokens (`torch.LongTensor`): + Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token + should be an index to a corresponding `code` vector in the codebook. + start_level (`int`, *optional*): + Level at which the decoding process will start. Default to 0. + end_level (`int`, *optional*): + Level at which the decoding process will start. Default to None. + bs_chunks (int, *optional*): + Number of chunks to process at the same time. + """ + token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens] + dequantised_states = [] + for i in range(bs_chunks): + music_tokens_i = [chunks[i] for chunks in token_chunks] + dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level) + dequantised_states.append(dequantised_state) + return torch.cat(dequantised_states, dim=0) + + def _encode(self, raw_audio, start_level=0, end_level=None): + # Encode + if end_level is None: + end_level = self.levels + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + music_tokens = self.bottleneck.encode(latent_states) + return music_tokens[start_level:end_level] + + def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1): + """ + Transforms the `input_audio` to a discrete representation made out of `music_tokens`. + + Args: + input_audio (`torch.Tensor`): + Raw audio which will be encoded to its discrete representation using the codebook. The closest `code` + form the codebook will be computed for each sequence of samples. + start_level (`int`, *optional*, defaults to 0): + Level at which the encoding process will start. Default to 0. + end_level (`int`, *optional*): + Level at which the encoding process will start. Default to None. + bs_chunks (int, *optional*, defaults to 1): + Number of chunks of raw audio to process at the same time. + """ + audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0) + music_tokens_list = [] + for chunk_i in audio_chunks: + music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level) + music_tokens_list.append(music_tokens_i) + music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)] + return music_tokens + + def sample(self, n_samples): + music_tokens = [ + torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu") + for music_tokens_shape in self.music_tokens_shapes + ] + return self.decode(music_tokens) + + def forward(self, raw_audio: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass of the VQ-VAE, encodes the `raw_audio` to latent states, which are then decoded for each level. + The commit loss, which ensure that the encoder's computed embeddings are close to the codebook vectors, is + computed. + + Args: + raw_audio (`torch.FloatTensor`): + Audio input which will be encoded and decoded. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]` + + + Example: + ```python + >>> from transformers import JukeboxVQVAE, set_seed + >>> import torch + + >>> model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval() + >>> set_seed(0) + >>> zs = [torch.randint(100, (4, 1))] + >>> model.decode(zs).shape + torch.Size([4, 8, 1]) + ``` + """ + + # Encode/Decode + input_audio = raw_audio.permute(0, 2, 1).float() + latent_states = [] + for level in range(self.levels): + encoder = self.encoders[level] + latent_state = encoder(input_audio) + latent_states.append(latent_state[-1]) + + _, music_tokens, commit_losses, _ = self.bottleneck(latent_states) + dequantised_states = [] + for level in range(self.levels): + decoder = self.decoders[level] + dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False) + dequantised_states.append(dequantised_state.permute(0, 2, 1)) + + commit_loss = sum(commit_losses) + loss = self.commit * commit_loss + + return dequantised_states, loss + + +class JukeboxMLP(nn.Module): + def __init__(self, config): + # a single channel is always used in original code + super().__init__() + embed_dim = config.hidden_size + hidden_dim = int(config.mlp_multiplier * embed_dim) + + self.c_fc = JukeboxConv1D(embed_dim, hidden_dim) + self.c_proj = JukeboxConv1D(hidden_dim, embed_dim) + self.act = ACT2FN[config.act_fn] + self.dropout = nn.Dropout(config.resid_dropout) + + def forward(self, hidden_states): + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class JukeboxLayerNorm(FusedLayerNorm): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + self.width = np.prod(normalized_shape) + self.max_numel = 65535 * self.width + + def forward(self, input): + if input.numel() > self.max_numel: + return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input) + else: + return super().forward(input).type_as(input) + + +class JukeboxAttention(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.embed_dim = config.hidden_size + self.n_heads = config.n_heads + self.dropout = config.attn_dropout + hidden_dim = int(config.attention_multiplier * self.embed_dim) + + self.head_dim = hidden_dim // config.n_heads + self.n_ctx = n_ctx + self.hidden_dim = hidden_dim + self.scale = self.head_dim**-0.25 + self.mask = config.mask + + if attn_func == "cross_attention": + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim) + self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2) + else: + self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3) + + self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim) + self.attn_dropout = nn.Dropout(config.attn_dropout) + self.resid_dropout = nn.Dropout(config.resid_dropout) + + # Sequence of length seq_len is factored as [blocks, seq_len // blocks] + self.attn_func = attn_func + if attn_func == "cross_attention": + self.qkv = self.decode_qkv + elif attn_func == "prime_attn": + self.qkv = self.prime_qkv + else: + self.qkv = self.factored_qkv + + ATTENTION_MAP = { + "dense_attn": (self.dense_attn, "autoregressive"), + "block_attn": (self.block_attn, "autoregressive"), + "transpose_block_attn": (self.transpose_block_attn, "autoregressive"), + "prev_block_attn": (self.prev_block_attn, None), + "summary_attn": (self.summary_attn, "summary"), + "summary_spread_attn": (self.summary_spread_attn, "summary"), + "cross_attention": (self.dense_attn, None), + "prime_attn": (self.prime_attn, "prime"), + } + self.attn, self.attn_mask = ATTENTION_MAP[attn_func] + + self.blocks = config.blocks + self.spread = config.spread + if self.blocks is not None: + self.block_ctx = self.n_ctx // self.blocks + + self.sample_t = 0 + self.cache = {} + self.encoder_len = config.nb_relevant_lyric_tokens # length of the encoder input ids + self.record_attn = False + + def _attn(self, query_states, key_states, value_states, sample): + scale = self.scale + if self.training: + attention_weight = torch.matmul(query_states * scale, key_states * scale) + else: + attention_weight = torch.matmul(query_states, key_states) + attention_weight.mul_(scale * scale) + attn_weight_type = attention_weight.dtype + attention_weight = attention_weight.float() + if self.mask: + # Generate appropriate mask to mask out all positions before current + # Might take up lot of memory for dense, so can cache it + mask = get_mask( + self.attn_mask, + query_states.size(-2), + key_states.size(-1), + self.blocks, + self.spread, + attention_weight.device, + sample, + self.sample_t, + ) + if mask is not None: + attention_weight = attention_weight * mask + -1e9 * (1 - mask) + attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type) + if self.record_attn: + self.attention_prob = attention_prob + if self.attn_func == "prime_attn": + # only keep music queries and lyrics keys/values + self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len] + attention_prob = self.attn_dropout(attention_prob) + context_states = torch.matmul(attention_prob, value_states) + return context_states + + def merge_heads(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1)) + return hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct merge_states + + def split_heads(self, hidden_states, is_key=False): + new_hidden_states_shape = ( + *hidden_states.size()[:-1], + self.n_heads, + hidden_states.size(-1) // self.n_heads, + ) + hidden_states = hidden_states.view(*new_hidden_states_shape) # in Tensorflow implem: fct split_states + if is_key: + return hidden_states.permute(0, 2, 3, 1) + else: + return hidden_states.permute(0, 2, 1, 3) + + def dense_attn(self, query, key, value, sample): + query = self.split_heads(query) + key = self.split_heads(key, is_key=True) + value = self.split_heads(value) + context_states = self._attn(query, key, value, sample) + context_states = self.merge_heads(context_states) + return context_states + + def block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + if query_length < seq_len: + seq_len = query_length + key = key[:, -seq_len:].contiguous() + value = value[:, -seq_len:].contiguous() + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def transpose_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block_len = (seq_len - 1) % block_ctx + key = key[:, block_len::block_ctx, :] + value = value[:, block_len::block_ctx, :] + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim) + query = query.transpose(1, 2).contiguous() + query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + key = key.transpose(1, 2).contiguous() + key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim) + value = value.transpose(1, 2).contiguous() + value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim) + + block_attn = self.dense_attn(query, key, value, sample) + block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim) + block_attn = block_attn.transpose(1, 2).contiguous() + block_attn = block_attn.view(batch_size, query_length, embed_dim) + + return block_attn + + def prev_block_attn(self, query, key, value, sample): + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + block = (seq_len - 1) // block_ctx + prev_l = (block - 1) * block_ctx + if block > 0: + key = key[:, prev_l : prev_l + block_ctx, :] + value = value[:, prev_l : prev_l + block_ctx, :] + else: + key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype) + value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + query_length = query.shape[1] + query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim) + + key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)) + key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :] + value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)) + value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim) + + if query_length < seq_len: + nb_query_blocks = query_length // block_ctx + nb_key_blocks = seq_len // block_ctx + seq_len = query_length + key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:] + value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_attn(self, query, key, value, sample): + blocks = self.blocks + block_ctx = self.block_ctx + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + key = key[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :] + key = torch.nn.functional.pad(key, (0, 0, 1, 0)) + + value = value[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :] + value = torch.nn.functional.pad(value, (0, 0, 1, 0)) + return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim) + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + key = torch.nn.functional.pad(key, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :] + value = torch.nn.functional.pad(value, (0, 0, 1, 0)) # batch_size, blocks, embed_dim + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def summary_spread_attn(self, query, key, value, sample): + blocks = self.blocks + spread = self.spread + + batch_size, seq_len, embed_dim = value.shape # For sample, query_len= 1, key_len = value_len = sample_t + if sample: + raise NotImplementedError + else: + key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous() + key = key.view(batch_size, blocks * spread, embed_dim) + + value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :] + value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous() + value = value.view(batch_size, blocks * spread, embed_dim) + + return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim) + + def prime_attn(self, query, key, value, sample): + encoder_len = self._encoder_len + key = key[:, :encoder_len] + value = value[:, :encoder_len] + return self.dense_attn(query, key, value, sample) + + def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + self.sample_t += curr_ctx + key, value = self._append_cache(key, value) + l_cache = self._suff_cache_len() + if self._cache_len() > l_cache: + self._slice_cache(-l_cache) + if curr_ctx > 1: + if self.attn_func != "dense_attn": + query = self._pad_to_block_ctx(query, query=True) + key = self._pad_to_block_ctx(key) + value = self._pad_to_block_ctx(value) + sample = False + else: + key = self.cache["key"] + value = self.cache["value"] + return query, key, value, sample + + def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + if last_encoder_hidden_states is not None: + raise TypeError("last_encoder_hidden_states should be None") + query, key, value = hidden_states.chunk(3, dim=2) + if sample: + if self._cache_len() < self._encoder_len: + self._append_cache(key, value) + if self._cache_len() > self._encoder_len: + self._slice_cache(0, self._encoder_len) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + return query, key, value, sample + + def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + query = hidden_states + if sample: + if self.sample_t == 0: + self.cache["key"], self.cache["value"] = self.c_enc_kv( + last_encoder_hidden_states.type_as(hidden_states) + ).chunk(2, dim=2) + key, value = self.cache["key"], self.cache["value"] + self.sample_t += curr_ctx + else: + key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2) + return query, key, value, sample + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + curr_ctx = hidden_states.shape[1] + hidden_states = self.c_attn(hidden_states) + query, key, value, sample = self.qkv( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + attention_scores = self.attn(query, key, value, sample) + if attention_scores.shape[1] != curr_ctx: + offset = self._offset(curr_ctx) + attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous() + attention_scores = self.c_proj(attention_scores) + return self.resid_dropout(attention_scores) + + @property + def _encoder_len(self): + encoder_len = self.encoder_len + encoder_blocks = (encoder_len // self.blocks) + 1 + return encoder_blocks * self.blocks + + def _offset(self, curr_ctx): + if self.attn_func == "dense_attn": + return 0 + return (self.sample_t - curr_ctx) % self.block_ctx + + def _pad_to_block_ctx(self, hidden_states, query=False): + seq_len = hidden_states.shape[1] + offset = self._offset(seq_len) if query else 0 + n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx + pad = n_blocks * self.block_ctx - seq_len - offset + if pad == 0 and offset == 0: + return hidden_states + else: + return F.pad(hidden_states, (0, 0, offset, pad)) + + def _cache_len(self): + return 0 if "key" not in self.cache else self.cache["key"].shape[1] + + def _suff_cache_len(self): + """ + Precondition: + key and value are appended with the current context and self.sample_t reflects the 1-indexed sample + location in the context. + """ + previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx + REQUIRED_CACHE_LEN = { + "dense_attn": self.sample_t, + "block_attn": (self.sample_t - 1) % self.block_ctx + 1, + "transpose_block_attn": self.sample_t, + "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length, + "cross_attn": self.encoder_len, + "prime_attn": min(self.sample_t, self._encoder_len), + } + + return REQUIRED_CACHE_LEN[self.attn_func] + + def _slice_cache(self, start, end=None): + self.cache["key"] = self.cache["key"][:, start:end] + self.cache["value"] = self.cache["value"][:, start:end] + + def _append_cache(self, key, value): + if "key" not in self.cache: + self.cache["key"] = key + self.cache["value"] = value + else: + old_key, old_value = key, value + key = torch.cat([self.cache["key"], old_key], dim=1) + value = torch.cat([self.cache["value"], old_value], dim=1) + del self.cache["key"] + del self.cache["value"] + del old_key + del old_value + self.cache["key"] = key + self.cache["value"] = value + return self.cache["key"], self.cache["value"] + + def del_cache(self): + self.sample_t = 0 + if "key" in self.cache: + del self.cache["key"] + if "value" in self.cache: + del self.cache["value"] + self.cache = {} + + +class JukeboxBlock(nn.Module): + def __init__(self, config, n_ctx, attn_func="dense_attn"): + super().__init__() + self.width = config.hidden_size + self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func) + + self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size) + self.mlp = JukeboxMLP(config) + self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size) + self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0 + self.attn_func = attn_func + + def forward(self, hidden_states, last_encoder_hidden_states, sample=False): + residuals = hidden_states + hidden_states = self.layer_norm_0(hidden_states) + hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample) + + output_states = self.layer_norm_1(residuals + hidden_states) + output_states = self.mlp(output_states) + if self.res_scale == 1.0: + output = residuals + hidden_states + output_states + else: + output = residuals + self.res_scale * (hidden_states + output_states) + return output + + +class JukeboxLayerStack(nn.Module): + def __init__(self, config, n_ctx): + super().__init__() + self.n_ctx = n_ctx + self.width = config.hidden_size + self.num_layers = config.num_layers + self.blocks = config.blocks + self.attention_pattern = config.attention_pattern + if self.blocks is not None: + self.block_ctx = n_ctx // self.blocks + self.encoder_len = config.nb_relevant_lyric_tokens + self.n_heads = config.n_heads + + # Orders of attn_func + attention_pattern = ATTENTION_PATTERNS[self.attention_pattern] + self._attn_mods = nn.ModuleList() + for depth in range(self.num_layers): + self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth))) + + self.saved_attn_weights = [] + + def set_record_attn(self, record_attn): + """ + Makes forward prop dump self-attention softmaxes to self.saved_attn_weights. + + Args: + record_attn (`Union[bool,set]`): + Either a set of layer indices indicating which layers to store, or a boolean value indicating Whether + to dump all. + """ + + def _should_record_attn(layer_idx): + if isinstance(record_attn, bool): + return record_attn + return layer_idx in record_attn + + for i, layer in enumerate(self._attn_mods): + layer.attn.record_attn = _should_record_attn(i) + + if not record_attn: + self.saved_attn_weights = [] + + def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False): + # Blocks + for i, attn_layer in enumerate(self._attn_mods): + if attn_layer.attn_func == "cross_attention": # attend to the lyrics + hidden_states = attn_layer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample + ) + else: + hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample) + if attn_layer.attn.record_attn: + self.saved_attn_weights.append(attn_layer.attn.c_attn.weight) + return hidden_states + + def del_cache(self): + for attn_layer in self._attn_mods: + attn_layer.attn.del_cache() + + +class JukeboxPositionalEmbedding(nn.Module): + def __init__(self, embed_dim, width): + super().__init__() + self.pos_emb = nn.Parameter(torch.empty((embed_dim, width))) + + def forward(self): + pos_emb = self.pos_emb + return pos_emb + + +class JukeboxConditionalAutoregressive(nn.Module): + def __init__( + self, + config, + n_ctx=None, + embed_dim=None, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=False, + ): + """ + Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly + set fro each configuration. + + Args: + config (`JukeboxPriorConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does + not load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + n_ctx (`int`, *optional*): + Number of tokens or lyrics tokens provided in a single pass. + embed_dim (`int`, *optional*): + Either equals to the dimension of the codebook, or the sum of n_vocab (lyrics) and codebook dimension, + if the model combines lyrics and music tokens, or simply n_vocab if the model is a seperate encoder + audio_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on audio. + metadata_conditioning (`bool`, *optional*, defaults to `False`): + Whether or not the prior supports conditionning on artitst, genres, lyrics and timing. + is_encoder (`bool`, *optional*, defaults to `False`): + Whether the model is an encoder only model. + """ + + super().__init__() + self.width = config.hidden_size + self.num_layers = config.num_layers + self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx + self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size + self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size) + self.embed_tokens_dropout = nn.Dropout(config.emb_dropout) + self.metadata_conditioning = metadata_conditioning + self.audio_conditioning = audio_conditioning + if not metadata_conditioning: + self.start_token = nn.Parameter(torch.empty((1, config.hidden_size))) + self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size) + self.pos_emb_dropout = nn.Dropout(config.emb_dropout) + + self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx) + self.is_encoder = is_encoder + self.encoder_len = config.nb_relevant_lyric_tokens + + if config.merged_decoder: + # Merged piped model uses this setup + self.add_cond_after_transformer = False + self.share_embed_tokens_fc_proj_out = False + else: + self.add_cond_after_transformer = True + self.share_embed_tokens_fc_proj_out = True + + if not is_encoder: + self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False) + if self.share_embed_tokens_fc_proj_out: + self.fc_proj_out.weight = self.embed_tokens.weight + self.loss = torch.nn.CrossEntropyLoss() + + def forward( + self, + tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + get_preds=False, + get_acts=False, + get_sep_loss=False, + ): + """ + Args: + tokens (`torch.tensor`): + Can represent music tokens, lyrics tokens or both, depending on the configuration. + """ + # Preprocess. + batch_size = tokens.shape[0] + with torch.no_grad(): + tokens = tokens.view(batch_size, -1).long() + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (batch_size, 1, self.width), + device=tokens.device, + dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype, + ) + + target = tokens # Target + hidden_states = self.embed_tokens(tokens) + # Shift by 1, and fill in start token + hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width) + else: + hidden_states[:, 0] = self.start_token + + hidden_states = ( + self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning + ) # Pos emb and dropout + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states + ) # Transformer + if self.add_cond_after_transformer: # Piped doesnt add x_cond + hidden_states = hidden_states + audio_conditioning + + activations = hidden_states + if self.is_encoder: + return hidden_states + + hidden_states = self.fc_proj_out(hidden_states) # Predictions + loss_fn = nn.CrossEntropyLoss() + if get_sep_loss: + lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim) + token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim) + + lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0) + music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0) + + loss = (lyric_loss, music_token_loss) # Note order! Lyric is first + else: + loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0) # Loss + + if get_preds: + return loss, hidden_states + elif get_acts: + return loss, activations + else: + return loss, None + + def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning): + if sample_t == 0: + hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to( + self.embed_tokens.weight.device + ) + if self.metadata_conditioning: + hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width) + else: + hidden_states[:, 0] = self.start_token + else: + hidden_states = self.embed_tokens(tokens) + if audio_conditioning.shape == (n_samples, self.n_ctx, self.width): + cond = audio_conditioning[:, sample_t : sample_t + 1, :] + else: + cond = audio_conditioning + # Pos emb, dropout is identity at eval time + hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond + return hidden_states, cond + + def sample( + self, + n_samples, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(self.fc_proj_out.device) + + with torch.no_grad(): + sampled_tokens = [] + tokens = None + if get_preds: + preds = [] + + iter = tqdm(range(0, sample_tokens), leave=False) + for sample_t in iter: + iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True) + hidden_states, cond = self.get_emb( + sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states.clone()) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # Sample and replace hidden_states + tokens = torch.distributions.Categorical(logits=hidden_states).sample() + sampled_tokens.append(tokens.clone()) + + del tokens + self.transformer.del_cache() + + tokens = torch.cat(sampled_tokens, dim=1) + if get_preds: + preds = torch.cat(preds, dim=1) + if get_preds: + return tokens, preds + else: + return tokens + + def split_chunks(self, length, chunk_size): + n_passes = (length + chunk_size - 1) // chunk_size + chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1] + return chunk_sizes + + def primed_sample( + self, + n_samples, + lyric_and_music_tokens, + audio_conditioning=None, + metadata_conditioning=None, + last_encoder_hidden_states=None, + temp=1.0, + top_k=0, + top_p=0.0, + get_preds=False, + chunk_size=None, + sample_tokens=None, + ): + if sample_tokens is None: + sample_tokens = self.n_ctx + # Preprocess. + batch_size = lyric_and_music_tokens.shape[0] + with torch.no_grad(): + lyric_and_music_tokens = lyric_and_music_tokens.view(batch_size, -1).long() + + sampled_audio = torch.split(lyric_and_music_tokens, 1, dim=1) + sampled_audio = list(sampled_audio) + + if not self.audio_conditioning: + audio_conditioning = torch.zeros( + (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype + ).to(lyric_and_music_tokens.device) + + with torch.no_grad(): + if get_preds: + preds = [] + + # Fill up key/value cache for past context by runing forward pass. + # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage. + if chunk_size is None: + chunk_size = len(sampled_audio) + chunk_sizes = self.split_chunks(len(sampled_audio), chunk_size) + x_primes = [] + start = 0 + token = None + + for current_chunk_size in tqdm(chunk_sizes, desc="Preparing past key value", leave=False): + sampled_audio_prime, conds_prime = [], [] + for sample_t in range(start, start + current_chunk_size): + x_prime, cond_prime = self.get_emb( + sample_t, n_samples, token, audio_conditioning, metadata_conditioning + ) + token = sampled_audio[sample_t] + sampled_audio_prime.append(x_prime) + conds_prime.append(cond_prime) + start = start + current_chunk_size + x_prime, cond_prime = torch.cat(sampled_audio_prime, dim=1), torch.cat(conds_prime, dim=1) + del sampled_audio_prime + del conds_prime + if not get_preds: + del cond_prime + x_prime = self.transformer(x_prime, last_encoder_hidden_states=last_encoder_hidden_states, sample=True) + + if get_preds: + if self.add_cond_after_transformer: + x_prime = x_prime + cond_prime + del cond_prime + x_primes.append(x_prime) + else: + del x_prime + + if get_preds: + x_prime = torch.cat(x_primes, dim=1) + x_prime = self.fc_proj_out(x_prime) # Predictions + preds.append(x_prime) + + # the input of the encoder and decoder can be merged into (lyrics, music tokens) + input_tokens = sampled_audio[-1] + + itererator = tqdm( + range(len(sampled_audio), sample_tokens), + desc=f"Sampling {len(range(len(sampled_audio), sample_tokens))} music tokens", + leave=False, + ) + for sample_t in itererator: + hidden_states, cond = self.get_emb( + sample_t, n_samples, input_tokens, audio_conditioning, metadata_conditioning + ) + + hidden_states = self.transformer( + hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True + ) + if self.add_cond_after_transformer: + hidden_states = hidden_states + cond + hidden_states = self.fc_proj_out(hidden_states) # Predictions + if get_preds: + preds.append(hidden_states) + # Adjust logits + hidden_states = hidden_states / temp + hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p) + # only music tokens are sampled + music_tokens = torch.distributions.Categorical(logits=hidden_states).sample() + sampled_audio.append(music_tokens.clone()) + input_tokens = music_tokens + + del input_tokens, music_tokens + self.transformer.del_cache() + + music_tokens = torch.cat(sampled_audio, dim=1) + if get_preds: + preds = torch.cat(preds, dim=1) + if get_preds: + return music_tokens, preds + else: + return music_tokens + + +class JukeboxMusicTokenConditioner(nn.Module): + """ + The `JukeboxMusicTokenConditioner` takes music tokens as an input (coresponding to the codes of the VQVAE's + codebook) and upsamples it using a single layer of decoder convolution block (the same is used in the VQVAE). + """ + + def __init__(self, config, level): + super().__init__() + self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size) + config.embed_dim = config.music_vocab_size # setting correct argument for the `JukeboxDecoder` + + self.upsampler = JukeboxDecoderConvBock( + config, + config.hidden_size, + config.res_conv_width, + config.res_conv_depth, + config.res_downs_t[level], + config.res_strides_t[level], + reverse_dilation=False, + ) + self.layer_norm = JukeboxLayerNorm(config.hidden_size) + + def forward(self, music_tokens, raw_audio_conditionning=None): + """ + Args: + music_tokens (`torch.LongTensor`): + Music tokens form the uper level in range(nb_discrete_codes) + raw_audio_conditionning (`torch.LongTensor`, *optional*): + Audio used when primed sampling, raw audio information that conditions the generation + """ + if raw_audio_conditionning is None: + raw_audio_conditionning = 0.0 + # Embed music_tokens + music_tokens = music_tokens.long() + hidden_states = self.embed_tokens(music_tokens) + hidden_states = hidden_states + raw_audio_conditionning + + # Run conditioner + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.upsampler(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class JukeboxRangeEmbedding(nn.Module): + """ + The `JukeboxRangeEmbedding` interpolate the given [pos_start, pos_end] to obtain an equivalent of time positional + embedding of length `n_ctx`. + + Binning process : For each pos in position tensor, find its bin [start,end) mapped to [0,1,...,bins-1] [start,end) + -> [0,1) -> [0, bins) -> floor -> [0,...,bins-1] NOTE: Open ended interval on right, so start <= pos < end, not <= + end + """ + + def __init__(self, n_time, embed_dim, range, out_width, clamp=False): + super().__init__() + self.n_time = n_time + self.embed_dim = embed_dim + self.emb = nn.Embedding(embed_dim, out_width) + self.pos_min, self.pos_max = range + self.clamp = clamp + + def forward(self, pos_start, pos_end=None): + # Check if [pos_start,pos_end] in [pos_min, pos_max) + if not len(pos_start.shape) == 2: + raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}") + if not (self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all(): + raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}") + + pos_start = pos_start.float() + if pos_end is not None: + if self.clamp: + pos_end = pos_end.clamp(self.pos_min, self.pos_max) + + pos_end = pos_end.float() + # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx + n_time = self.n_time + if n_time != 1: + interpolation = ( + torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time + ) + position = pos_start + (pos_end - pos_start) * interpolation + else: + position = pos_start + + # Bin each value to bins_ + # [0,1) -> [0,1..,embed_dim) -> [0,1...,embed_dim-1 + normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min) + bins_ = (self.embed_dim * normalised_position).floor().long().detach() + return self.emb(bins_) + + +class JukeboxLabelConditioner(nn.Module): + def __init__(self, config, include_time_signal): + super().__init__() + + embed_dim = config.hidden_size + timing_dims = config.timing_dims + sampling_rate = config.sampling_rate + nb_genres, nb_artists = config.metadata_dims + music_tokens_shape = config.n_ctx + + self.max_nb_genres = config.max_nb_genres + self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim) + self.artist_emb = nn.Embedding(nb_artists, embed_dim) + self.include_time_signal = include_time_signal + if self.include_time_signal: + total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate) + absolute_pos_range = (0.0, config.max_duration * sampling_rate) + relative_pos_range = (0.0, 1.0) + self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim) + self.absolute_pos_emb = JukeboxRangeEmbedding( + music_tokens_shape, timing_dims, absolute_pos_range, embed_dim + ) + self.relative_pos_emb = JukeboxRangeEmbedding( + music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True + ) + + def forward(self, metadata): + total_length = metadata[:, 0:1] + offset = metadata[:, 1:2] + length = metadata[:, 2:3] + artist = metadata[:, 3:4] + genre = metadata[:, 4:] + + # Start embedding of length 1 + artist_emb = self.artist_emb(artist) + # Empty genre slots are denoted by -1. We mask these out. + mask = (genre >= 0).float().unsqueeze(2) + genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True) + start_emb = genre_emb + artist_emb + + # Pos embedding of length n_ctx + if self.include_time_signal: + start, end = offset, offset + length + total_length = total_length.float() + start = start.float() + end = end.float() + pos_emb = ( + self.total_length_emb(total_length) + + self.absolute_pos_emb(start, end) + + self.relative_pos_emb(start / total_length, end / total_length) + ) + else: + pos_emb = None + return start_emb, pos_emb + + +class JukeboxPrior(PreTrainedModel): + """ + The JukeboxPrior class, which is a wrapper around the various conditioning and the transformer. JukeboxPrior can be + seen as language models trained on music. They model the next `music token` prediction task. If a (lyric) `encoderù + is defined, it also models the `next character` prediction on the lyrics. Can be conditionned on timing, artist, + genre, lyrics and codes from lower-levels Priors. + + Args: + config (`JukeboxPriorConfig`): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. + level (`int`, *optional*): + Current level of the Prior. Should be in range `[0,nb_priors]`. + nb_priors (`int`, *optional*, defaults to 3): + Total number of priors. + vqvae_encoder (`Callable`, *optional*): + Encoding method of the VQVAE encoder used in the forward pass of the model. Passing functions instead of + the vqvae module to avoid getting the parameters. + vqvae_decoder (`Callable`, *optional*): + Decoding method of the VQVAE decoder used in the forward pass of the model. Passing functions instead of + the vqvae module to avoid getting the parameters. + """ + + config_class = JukeboxPriorConfig + + def _init_weights(self, module): + init_scale = self.config.init_scale + + if isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConv1D): + if self.config.zero_out: + module.weight.data.zero_() + else: + module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxPositionalEmbedding): + module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxRangeEmbedding): + module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"): + module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"): + module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale) + elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: + module.conv1d_2.weigth.data.zero_() + module.conv1d_2.bias.data.zero_() + if isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None): + super().__init__(config) + # Passing functions instead of the vqvae module to avoid getting params, only used in the + # forward loop + self.vqvae_encoder = vqvae_encoder + self.vqvae_decoder = vqvae_decoder + + self.levels = nb_priors + self.level = level if level is not None else config.level + + self.base_model_prefix = f"priors.{self.level}" + + self.n_ctx = config.n_ctx + + self.lyric_conditioning = config.nb_relevant_lyric_tokens > 0 + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + self.encoder_loss_fraction = config.encoder_loss_fraction + + # Audio conditioning : conditioning on music tokens (either from audio or from previous levels or both) + self.audio_conditioning = self.level != 0 + self.cond_level = self.level - 1 + if self.audio_conditioning: + self.conditioner_blocks = JukeboxMusicTokenConditioner(config, self.level) + + # metadata conditioning : contioning on timing, genres, and artist + self.metadata_conditioning = config.metadata_conditioning + if self.metadata_conditioning: + self.metadata_embedding = JukeboxLabelConditioner(config, include_time_signal=not self.audio_conditioning) + + # define encoder-decoder or encoder and decoder + self.is_encoder_decoder = config.is_encoder_decoder + if config.is_encoder_decoder: + # encoder-decoder transformer + self.input_shapes = [config.nb_relevant_lyric_tokens, config.n_ctx] + self.embed_dim_shift = [0, config.lyric_vocab_size] + self.width = config.hidden_size + + self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens + + self.prior = JukeboxConditionalAutoregressive( + config, + n_ctx=config.nb_relevant_lyric_tokens + config.n_ctx, + embed_dim=config.lyric_vocab_size + config.music_vocab_size, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=True, + ) + + else: + # Separate encoder-decoder transformer + encoder_config = config.encoder_config + + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + self.lyric_acts_width = encoder_config.hidden_size + self.encoder_width = config.hidden_size + self.encoder_dim = config.lyric_vocab_size + self.encoder = JukeboxConditionalAutoregressive( + encoder_config, + n_ctx=self.nb_relevant_lyric_tokens, + embed_dim=self.encoder_dim, + audio_conditioning=False, + metadata_conditioning=False, + is_encoder=True, + ) + self.encoder.proj_in = JukeboxConv1D(encoder_config.hidden_size, config.hidden_size) + self.encoder.final_layer_norm = JukeboxLayerNorm(config.hidden_size) + self.encoder.lm_head = nn.Linear(config.hidden_size, config.lyric_vocab_size, bias=False) + else: + self.nb_relevant_lyric_tokens = 0 + + # decoder model on the tokens + self.prior = JukeboxConditionalAutoregressive( + config, + audio_conditioning=(self.audio_conditioning or self.metadata_conditioning), + metadata_conditioning=self.metadata_conditioning, + ) + + self.next_token_prediction_loss_dims = config.n_ctx + self.total_loss_dims = self.nb_relevant_lyric_tokens + self.next_token_prediction_loss_dims + + self.downsamples = [stride**down for stride, down in zip(config.res_strides_t, config.res_downs_t)] + self.cond_downsample = self.downsamples[self.level] if self.level != 0 else None + self.raw_to_tokens = np.prod(self.downsamples[: nb_priors - self.level]) + self.sample_length = self.n_ctx * self.raw_to_tokens + + logger.info( + f"Level:{self.level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample" + f" length:{self.sample_length}" + ) + + def get_metadata(self, labels, start, total_length, offset, get_indices=False): + metadata = labels.clone() + metadata[:, 0] = total_length + # Set sample_length to match this level + metadata[:, 2] = int(self.sample_length) + + # Set offset + metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens) + # here since metadata has the full token_list, we just need to selected the ones that are relevant + + # Set lyric tokens + metadata, indices = self.set_metadata_lyric_tokens(metadata) + if get_indices: + return metadata, indices + else: + return metadata + + def set_metadata_lyric_tokens(self, labels): + """ + Processes the full labels to only retrieve the relevant lyric tokens and keep the metadata conditioning tokens. + """ + if self.nb_relevant_lyric_tokens > 0: + tokens_list = torch.zeros( + (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device + ) + indices_list = [] # whats the index of each current character in original array + for idx in range(labels.shape[0]): + full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :] + total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2] + tokens, indices = get_relevant_lyric_tokens( + full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration + ) + tokens_list[idx, :] = tokens + indices_list.append(indices) + + return ( + torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1), + indices_list, + ) + else: + return labels, None + + def get_music_tokens_conds(self, music_tokens, start, end): + """ + Extracts current level's conditioning music tokens. + """ + if self.level != 0: + music_tokens_cond = music_tokens[self.level - 1] + music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample] + missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1] + if missing_cond_len > 0: + init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device) + music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long() + music_tokens_conds = [music_tokens_cond] + else: + music_tokens_conds = None + return music_tokens_conds + + def prior_preprocess(self, tokens, conds): + """ + Shifts the input tokens to account for the dictionary merge. The embed_dim_shift give by how much the music + tokens should be shifted by. It is equal to `lyric_vocab_size`. + """ + batch_size = tokens[0].shape[0] + for i in range(len(tokens)): + tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1) + + for i in range(len(conds)): + if conds[i] is None: + conds[i] = torch.zeros( + (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device + ) + + return torch.cat(tokens, dim=1), torch.cat(conds, dim=1) + + def prior_postprocess(self, tokens): + """ + Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is + shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music + tokens. + """ + batch_size = tokens.shape[0] + dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0]) + tokens = list(torch.split(tokens, dims, dim=1)) + + # Some of the input tokens might be shifted to take into account the voccabulary fusion + for i in range(len(tokens)): + bins_shift = int(self.embed_dim_shift[i]) + tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1) + tokens[i] = torch.clamp(tokens[i], min=0) + # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift + return tokens[-1] + + def embed_tokens(self, music_tokens_conds): + """ + Embeds the upper level music tokens and upsamples them to provide as audio conditioning. + """ + music_tokens_conds = music_tokens_conds[: self.cond_level + 1] + audio_conditioning = None + for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))): + audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning) + return audio_conditioning + + def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1): + """ + Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + # Get latents + with torch.no_grad(): + latent_states = self.vqvae_encoder( + hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return latent_states + + def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1): + """ + Usamples the sequence of codebook vectors to a raw audio. + """ + if start_level is None: + start_level = self.level + if end_level is None: + end_level = self.levels + with torch.no_grad(): + output = self.vqvae_decoder( + music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks + ) + return output + + def get_cond(self, music_tokens_conds, metadata): + """ + Converts the input tokens to input_embeddings. Splits the lyrics form the rest of the metadata. Lyric tokens + can be None. + """ + if metadata is not None: + n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens + metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:] + else: + metadata, lyric_tokens = None, None + metadata_conditioning, metadata_pos = ( + self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None) + ) + audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos + return audio_conditioning, metadata_conditioning, lyric_tokens + + def sample( + self, + n_samples, + music_tokens=None, + music_tokens_conds=None, + metadata=None, + temp=1.0, + top_k=0, + top_p=0.0, + chunk_size=None, + sample_tokens=None, + ): + """ + Ancestral/Prime sampling a window of tokens using the provided conditioning and metadatas. + + Args: + n_samples (`int`): + Number of samples to generate. + music_tokens (`List[torch.LongTensor]`, *optional*): + Previously gemerated tokens at the current level. Used as context for the generation. + music_tokens_conds (`List[torch.FloatTensor]`, *optional*): + Upper-level music tokens generated by the previous prior model. Is `None` if the generation is not + conditionned on the upper-level tokens. + metadata (`List[torch.LongTensor]`, *optional*): + List containing the metatdata tensor with the artist, genre and the lyric tokens. + temp (`float`, *optional*, defaults to 1.0): + Sampling temperature. + top_k (`int`, *optional*, defaults to 0): + Top k probabilities used for filtering. + top_p (`float`, *optional*, defaults to 0.0): + Top p probabilities used for filtering. + chunk_size (`int`, *optional*): + Size of the chunks used to prepare the cache of the transformer. + sample_tokens (`int`, *optional*): + Number of tokens to sample. + + """ + no_past_context = music_tokens is None or music_tokens.shape[1] == 0 + name = {True: "Ancestral", False: "Primed"}[no_past_context] + logger.info(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}") + + with torch.no_grad(): + # Currently audio_conditioning only uses immediately above layer + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + if self.is_encoder_decoder: + if no_past_context: # the prime_sample function will be used with music_tokens set to None + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens], [None, audio_conditioning] + ) + else: + lyric_and_music_tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + if sample_tokens is not None: + sample_tokens += self.nb_relevant_lyric_tokens + music_tokens = self.prior.primed_sample( + n_samples, + lyric_and_music_tokens, + audio_conditioning, + metadata_conditioning, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + music_tokens = self.prior_postprocess(music_tokens) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens, sample=True) + if no_past_context: + music_tokens = self.prior.sample( + n_samples, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + sample_tokens=sample_tokens, + ) + else: + music_tokens = self.prior.primed_sample( + n_samples, + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + temp=temp, + top_k=top_k, + top_p=top_p, + chunk_size=chunk_size, + sample_tokens=sample_tokens, + ) + return music_tokens + + def get_encoder_states(self, lyric_tokens, sample=False): + """ + Retrieve the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through + the lyric encoder. + """ + if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning: + if sample: + self.encoder = self.encoder.to(lyric_tokens.device) + lyric_acts = self.encoder(lyric_tokens, None, None, None) + lyric_acts = self.encoder.proj_in(lyric_acts) + last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts) + else: + last_encoder_hidden_states = None + return last_encoder_hidden_states + + def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics): + """ + Computes the loss for the lyric encoder: next lyric token prediction. + """ + if self.lyric_conditioning: + last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states) + encoder_loss = nn.functional.cross_entropy( + last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1) + ) / np.log(2.0) + else: + encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device) + return encoder_loss + + def forward_tokens( + self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False + ): + """ + Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the + vqvae's encoding layers. + """ + if get_attn_weights: + self.prior.transformer.set_record_attn(get_attn_weights) + audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata) + + if self.is_encoder_decoder: # the preprocess returns the full tokens (Lyrics and Music tokens), shifted + tokens, audio_conditioning = self.prior_preprocess( + [lyric_tokens, music_tokens], [None, audio_conditioning] + ) + (encoder_loss, next_token_prediction_loss), preds = self.prior( + tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds + ) + else: + last_encoder_hidden_states = self.get_encoder_states(lyric_tokens) + encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens) + next_token_prediction_loss, preds = self.prior( + music_tokens, + audio_conditioning, + metadata_conditioning, + last_encoder_hidden_states, + get_preds=get_preds, + ) + loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims + loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims + + metrics = { + "bpd": next_token_prediction_loss.detach().clone(), + "encoder_loss": encoder_loss.detach().clone(), + "next_token_prediction_loss": next_token_prediction_loss.detach().clone(), + } + if get_preds: + metrics["preds"] = preds.detach().clone() + if get_attn_weights: + saved_attn_weights = self.prior.transformer.saved_attn_weights + self.prior.transformer.set_record_attn(False) + return saved_attn_weights + else: + return loss, metrics + + def forward( + self, + hidden_states: torch.Tensor, + metadata: Optional[List[torch.LongTensor]], + decode: Optional[bool] = False, + get_preds: Optional[bool] = False, + ) -> List[torch.Tensor]: + """ + Encode the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens` + function. The loss is the sum of the `encoder` loss and the `decoder` loss. + + Args: + hidden_states (`torch.Tensor`): + Hidden states which should be raw audio + metadata (`List[torch.LongTensor]`, *optional*): + List containing the metadata conditioning tensorwith the lyric and the metadata tokens. + decode (`bool`, *optional*, defaults to `False`): + Whether or not to decode the encoded to tokens. + get_preds (`bool`, *optional*, defaults to `False`): + Whether or not to return the actual predicitons of the model. + """ + batch_size = hidden_states.shape[0] + music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size) + loss, metrics = self.forward_tokens( + music_tokens=music_tokens, + music_tokens_conds=music_tokens_conds, + metadata=metadata, + get_preds=get_preds, + ) + if decode: + dequantised_states = self.decode([music_tokens, *music_tokens_conds]) + else: + dequantised_states = None + return dequantised_states, loss, metrics + + +class JukeboxPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = JukeboxConfig + base_model_prefix = "jukebox" + supports_gradient_checkpointing = False + + def _init_weights(self, module): + if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE): + module.apply(module._init_weights) + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + +JUKEBOX_SAMPLING_INPUT_DOCSTRING = r""" + labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` : + List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to + condition the generation. + sampling_kwargs (`Dict[Any]`): + Various additional sampling arguments that are used by the `_sample` function. A detail list of the + arguments can bee seen in the [`_sample`] function documentation. +""" + + +@add_start_docstrings( + """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`, + `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If + you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior + individually. + """, + JUKEBOX_START_DOCSTRING, +) +class JukeboxModel(JukeboxPreTrainedModel): + _no_split_modules = ["JukeboxBlock"] + + def __init__(self, config): + super().__init__(config) + vqvae_config = config.vqvae_config + self.vqvae = JukeboxVQVAE(vqvae_config) + self.set_shared_params(config) + self.priors = nn.ModuleList( + [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)] + ) + + def set_shared_params(self, model_config): + """ + Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig` + is nest, and is thus unreachable in the `from_dict` function + """ + for config in model_config.prior_configs: + config.sampling_rate = model_config.sampling_rate + config.timing_dims = model_config.timing_dims + config.min_duration = model_config.min_duration + config.max_duration = model_config.max_duration + config.max_nb_genres = model_config.max_nb_genres + config.metadata_conditioning = model_config.metadata_conditioning + + def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1): + return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks) + + def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1): + return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks) + + def split_batch(self, obj, n_samples, split_size): + n_passes = (n_samples + split_size - 1) // split_size + if isinstance(obj, torch.Tensor): + return torch.split(obj, split_size, dim=0) + elif isinstance(obj, list): + return list(zip(*[torch.split(item, split_size, dim=0) for item in obj])) + elif obj is None: + return [None] * n_passes + else: + raise TypeError("Unknown input type") + + # Sample a partial window of length= self.priors[level].n_ctx: + iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length) + for start in iterator: + music_tokens = self.sample_single_window( + music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size + ) + + else: + music_tokens = self.sample_partial_window( + music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size + ) + return music_tokens + + @torch.no_grad() + def _sample( + self, + music_tokens, + labels, + sample_levels, + metas=None, + chunk_size=32, + sampling_temperature=0.98, + lower_batch_size=16, + max_batch_size=16, + sample_length_in_seconds=24, + compute_alignments=False, + sample_tokens=None, + offset=0, + save_results=True, + sample_length=None, + ) -> List[torch.LongTensor]: + """ + Core sampling function used to generate music tokens. Iterates over the provided list of levels, while saving + the generated raw audio at each step. + + Args: + music_tokens (`List[torch.LongTensor]`): + A sequence of music tokens of length `self.levels` which will be used as context to continue the + sampling process. Should have `self.levels` tensors, each corresponding to the generation at a certain + level. + labels (`List[torch.LongTensor]`): + List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre + + lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens + which are used to condition the generation. + sample_levels (`List[int]`): + List of the desired levels at which the sampling will be done. A level is equivalent to the index of + the prior in the list of priors + metas (`List[Any]`, *optional*): + Metadatas used to generate the `labels` + chunk_size (`int`, *optional*, defaults to 32): + Size of a chunk of audio, used to fill up the memory in chuncks to prevent OOM erros. Bigger chunks + means faster memory filling but more consumption. + sampling_temperature (`float`, *optional*, defaults to 0.98): + Temperature used to ajust the randomness of the sampling. + lower_batch_size (`int`, *optional*, defaults to 16): + Maximum batch size for the lower level priors + max_batch_size (`int`, *optional*, defaults to 16): + Maximum batch size for the top level priors + sample_length_in_seconds (`int`, *optional*, defaults to 24): + Desired length of the generation in seconds + compute_alignments (`bool`, *optional*, defaults to `False`): + Whether or not to compute the alignment between the lyrics and the audio using the top_prior + sample_tokens (`int`, *optional*): + Precise number of tokens that should be sampled at each level. This is mostly useful for running dummy + experiments + offset (`int`, *optional*, defaults to 0): + Audio offset used as conditioning, corresponds to the starting sample in the music. If the offset is + greater than 0, the lyrics will be shifted take that intoaccount + save_results (`bool`, *optional*, defaults to `True`): + Whether or not to save the intermediate results. If `True`, will generate a folder named with the start + time. + sample_length (`int`, *optional*): + Desired length of the generation in samples. + + Returns: torch.Tensor + + Example: + + ```python + >>> from transformers import AutoTokenizer, JukeboxModel, set_seed + >>> import torch + + >>> metas = dict(artist="Zac Brown Band", genres="Country", lyrics="I met a traveller from an antique land") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval() + + >>> labels = tokenizer(**metas)["input_ids"] + >>> set_seed(0) + >>> zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)] + >>> zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False) + >>> zs[0] + tensor([[1853, 1369, 1150, 1869, 1379, 1789, 519, 710, 1306, 1100, 1229, 519, + 353, 1306, 1379, 1053, 519, 653, 1631, 1467, 1229, 1229, 10, 1647, + 1254, 1229, 1306, 1528, 1789, 216, 1631, 1434, 653, 475, 1150, 1528, + 1804, 541, 1804, 1434]]) + ``` + """ + + top_prior = self.priors[0] + if sample_length is not None: + total_length = sample_length + else: + total_length = ( + int(sample_length_in_seconds * self.config.sampling_rate) // top_prior.raw_to_tokens + ) * top_prior.raw_to_tokens + + if sample_levels is None: + sample_levels = range(len(self.priors)) + + # total length of the signal, might be bit different from the actual generated length + self.total_length = total_length + for level in sample_levels: + sampling_kwargs = { + "temp": 0.99 if level == len(self.priors) - 1 else sampling_temperature, + "chunk_size": chunk_size, + "sample_tokens": sample_tokens, + } + # Set correct total_length, hop_length, labels and sampling_kwargs for level + + total_token_to_sample = total_length // self.priors[level].raw_to_tokens + hop_length = int(self.config.hop_fraction[level] * self.priors[level].n_ctx) + max_batch_size = lower_batch_size if level != sample_levels else max_batch_size + music_tokens = self.sample_level( + music_tokens, + labels[level], + offset, + sampling_kwargs, + level, + total_token_to_sample, + hop_length, + max_batch_size, + ) + + if save_results: + self.vqvae.to(music_tokens[level].device) + # Decode sample + with torch.no_grad(): + start_level = len(self.priors) - level - 1 # vqvae levels are reversed + raw_audio = self.vqvae.decode( + music_tokens[: level + 1], start_level=start_level, bs_chunks=music_tokens[level].shape[0] + ) + logdir = f"jukebox/level_{level}" + if not os.path.exists(logdir): + os.makedirs(logdir) + save_temp_audio(logdir, level, metas=metas, aud=raw_audio.float()) + if compute_alignments and self.priors[0] is not None and self.priors[0].nb_relevant_lyric_tokens > 0: + with torch.no_grad(): + alignments = get_alignment(music_tokens, labels[0], self.priors[0], self.config) + torch.save({"alignments": alignments}, f"{logdir}/lyric_alignments.pt") + + return music_tokens + + @add_start_docstrings( + """ + Generates music tokens based on the provided `labels. Will start at the desired prior level and automatically + upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use + the VQ-VAE decoder to convert the music tokens to raw audio. + + Args: + labels (`List[torch.LongTensor]`) : + List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre + + lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens + which are used to condition the generation. + n_samples (`int`, *optional*, default to 1) : + Number of samples to be generated in parallel. + """, + ) + def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]: + """ + Example: + + ```python + >>> from transformers import AutoTokenizer, JukeboxModel, set_seed + + >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval() + >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + + >>> lyrics = "Hey, are you awake? Can you talk to me?" + >>> artist = "Zac Brown Band" + >>> genre = "Country" + >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics) + >>> set_seed(0) + >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400) + + >>> with torch.no_grad(): + ... model.decode(music_tokens)[:, :10].squeeze(-1) + tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405, + -0.0818, -0.0697]]) + ``` + """ + + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + music_tokens = [ + torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors)) + ] + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Generates a continuation of the previously generated tokens. + + Args: + music_tokens (`List[torch.LongTensor]` of length `self.levels` ) : + A sequence of music tokens which will be used as context to continue the sampling process. Should have + `self.levels` tensors, each corresponding to the generation at a certain level. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Upsamples a sequence of music tokens using the prior at level `level`. + + Args: + music_tokens (`List[torch.LongTensor]` of length `self.levels` ) : + A sequence of music tokens which will be used as context to continue the sampling process. Should have + `self.levels` tensors, each corresponding to the generation at a certain level. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1))) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + @add_start_docstrings( + """Generate a raw audio conditioned on the provided `raw_audio` which is used as conditioning at each of the + generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are + used: as conditioning for each level, which means that no ancestral sampling is required. + + Args: + raw_audio (`List[torch.Tensor]` of length `n_samples` ) : + A list of raw audio that will be used as conditioning information for each samples that will be + generated. + """, + JUKEBOX_SAMPLING_INPUT_DOCSTRING, + ) + def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]: + sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors)))) + self.vqvae.to(raw_audio.device).float() + with torch.no_grad(): + music_tokens = self.vqvae.encode( + raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0] + ) + music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs) + return music_tokens + + +__all__ = ["JukeboxModel", "JukeboxPreTrainedModel", "JukeboxVQVAE", "JukeboxPrior"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py new file mode 100644 index 0000000000000000000000000000000000000000..e08ab179a807d9af7050e92838ced80c7d0d1742 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/jukebox/tokenization_jukebox.py @@ -0,0 +1,407 @@ +# coding=utf-8 +# Copyright 2022 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI Jukebox.""" + +import json +import os +import re +import unicodedata +from json.encoder import INFINITY +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import regex + +from ....tokenization_utils import AddedToken, PreTrainedTokenizer +from ....tokenization_utils_base import BatchEncoding +from ....utils import TensorType, is_flax_available, is_tf_available, is_torch_available, logging +from ....utils.generic import _is_jax, _is_numpy + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "artists_file": "artists.json", + "lyrics_file": "lyrics.json", + "genres_file": "genres.json", +} + + +class JukeboxTokenizer(PreTrainedTokenizer): + """ + Constructs a Jukebox tokenizer. Jukebox can be conditioned on 3 different inputs : + - Artists, unique ids are associated to each artist from the provided dictionary. + - Genres, unique ids are associated to each genre from the provided dictionary. + - Lyrics, character based tokenization. Must be initialized with the list of characters that are inside the + vocabulary. + + This tokenizer does not require training. It should be able to process a different number of inputs: + as the conditioning of the model can be done on the three different queries. If None is provided, defaults values will be used.: + + Depending on the number of genres on which the model should be conditioned (`n_genres`). + ```python + >>> from transformers import JukeboxTokenizer + + >>> tokenizer = JukeboxTokenizer.from_pretrained("openai/jukebox-1b-lyrics") + >>> tokenizer("Alan Jackson", "Country Rock", "old town road")["input_ids"] + [tensor([[ 0, 0, 0, 6785, 546, 41, 38, 30, 76, 46, 41, 49, + 40, 76, 44, 41, 27, 30]]), tensor([[ 0, 0, 0, 145, 0]]), tensor([[ 0, 0, 0, 145, 0]])] + ``` + + You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer or when you + call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. + + + + If nothing is provided, the genres and the artist will either be selected randomly or set to None + + + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to: + this superclass for more information regarding those methods. + + However the code does not allow that and only supports composing from various genres. + + Args: + artists_file (`str`): + Path to the vocabulary file which contains a mapping between artists and ids. The default file supports + both "v2" and "v3" + genres_file (`str`): + Path to the vocabulary file which contain a mapping between genres and ids. + lyrics_file (`str`): + Path to the vocabulary file which contains the accepted characters for the lyrics tokenization. + version (`List[str]`, `optional`, default to `["v3", "v2", "v2"]`) : + List of the tokenizer versions. The `5b-lyrics`'s top level prior model was trained using `v3` instead of + `v2`. + n_genres (`int`, `optional`, defaults to 1): + Maximum number of genres to use for composition. + max_n_lyric_tokens (`int`, `optional`, defaults to 512): + Maximum number of lyric tokens to keep. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + artists_file, + genres_file, + lyrics_file, + version=["v3", "v2", "v2"], + max_n_lyric_tokens=512, + n_genres=5, + unk_token="<|endoftext|>", + **kwargs, + ): + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + self.version = version + self.max_n_lyric_tokens = max_n_lyric_tokens + self.n_genres = n_genres + self._added_tokens_decoder = {0: unk_token} + + with open(artists_file, encoding="utf-8") as vocab_handle: + self.artists_encoder = json.load(vocab_handle) + + with open(genres_file, encoding="utf-8") as vocab_handle: + self.genres_encoder = json.load(vocab_handle) + + with open(lyrics_file, encoding="utf-8") as vocab_handle: + self.lyrics_encoder = json.load(vocab_handle) + + oov = r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+" + # In v2, we had a n_vocab=80 and in v3 we missed + and so n_vocab=79 of characters. + if len(self.lyrics_encoder) == 79: + oov = oov.replace(r"\-'", r"\-+'") + + self.out_of_vocab = regex.compile(oov) + self.artists_decoder = {v: k for k, v in self.artists_encoder.items()} + self.genres_decoder = {v: k for k, v in self.genres_encoder.items()} + self.lyrics_decoder = {v: k for k, v in self.lyrics_encoder.items()} + super().__init__( + unk_token=unk_token, + n_genres=n_genres, + version=version, + max_n_lyric_tokens=max_n_lyric_tokens, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.artists_encoder) + len(self.genres_encoder) + len(self.lyrics_encoder) + + def get_vocab(self): + return { + "artists_encoder": self.artists_encoder, + "genres_encoder": self.genres_encoder, + "lyrics_encoder": self.lyrics_encoder, + } + + def _convert_token_to_id(self, list_artists, list_genres, list_lyrics): + """Converts the artist, genre and lyrics tokens to their index using the vocabulary. + The total_length, offset and duration have to be provided in order to select relevant lyrics and add padding to + the lyrics token sequence. + """ + artists_id = [self.artists_encoder.get(artist, 0) for artist in list_artists] + for genres in range(len(list_genres)): + list_genres[genres] = [self.genres_encoder.get(genre, 0) for genre in list_genres[genres]] + list_genres[genres] = list_genres[genres] + [-1] * (self.n_genres - len(list_genres[genres])) + + lyric_ids = [[self.lyrics_encoder.get(character, 0) for character in list_lyrics[0]], [], []] + return artists_id, list_genres, lyric_ids + + def _tokenize(self, lyrics): + """ + Converts a string into a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. Only the lyrics are split into character for the character-based vocabulary. + """ + # only lyrics are not tokenized, but character based is easily handled + return list(lyrics) + + def tokenize(self, artist, genre, lyrics, **kwargs): + """ + Converts three strings in a 3 sequence of tokens using the tokenizer + """ + artist, genre, lyrics = self.prepare_for_tokenization(artist, genre, lyrics) + lyrics = self._tokenize(lyrics) + return artist, genre, lyrics + + def prepare_for_tokenization( + self, artists: str, genres: str, lyrics: str, is_split_into_words: bool = False + ) -> Tuple[str, str, str, Dict[str, Any]]: + """ + Performs any necessary transformations before tokenization. + + Args: + artist (`str`): + The artist name to prepare. This will mostly lower the string + genres (`str`): + The genre name to prepare. This will mostly lower the string. + lyrics (`str`): + The lyrics to prepare. + is_split_into_words (`bool`, *optional*, defaults to `False`): + Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the + tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace) + which it will tokenize. This is useful for NER or token classification. + """ + for idx in range(len(self.version)): + if self.version[idx] == "v3": + artists[idx] = artists[idx].lower() + genres[idx] = [genres[idx].lower()] + else: + artists[idx] = self._normalize(artists[idx]) + ".v2" + genres[idx] = [ + self._normalize(genre) + ".v2" for genre in genres[idx].split("_") + ] # split is for the full dictionary with combined genres + + if self.version[0] == "v2": + self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-'\"()\[\] \t\n]+") + vocab = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,:;!?-+'\"()[] \t\n" + self.vocab = {vocab[index]: index + 1 for index in range(len(vocab))} + self.vocab[""] = 0 + self.n_vocab = len(vocab) + 1 + self.lyrics_encoder = self.vocab + self.lyrics_decoder = {v: k for k, v in self.vocab.items()} + self.lyrics_decoder[0] = "" + else: + self.out_of_vocab = regex.compile(r"[^A-Za-z0-9.,:;!?\-+'\"()\[\] \t\n]+") + + lyrics = self._run_strip_accents(lyrics) + lyrics = lyrics.replace("\\", "\n") + lyrics = self.out_of_vocab.sub("", lyrics), [], [] + return artists, genres, lyrics + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _normalize(self, text: str) -> str: + """ + Normalizes the input text. This process is for the genres and the artist + + Args: + text (`str`): + Artist or Genre string to normalize + """ + + accepted = ( + [chr(i) for i in range(ord("a"), ord("z") + 1)] + + [chr(i) for i in range(ord("A"), ord("Z") + 1)] + + [chr(i) for i in range(ord("0"), ord("9") + 1)] + + ["."] + ) + accepted = frozenset(accepted) + pattern = re.compile(r"_+") + text = "".join([c if c in accepted else "_" for c in text.lower()]) + text = pattern.sub("_", text).strip("_") + return text + + def convert_lyric_tokens_to_string(self, lyrics: List[str]) -> str: + return " ".join(lyrics) + + def convert_to_tensors( + self, inputs, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False + ): + """ + Convert the inner content to tensors. + + Args: + tensor_type (`str` or [`~utils.TensorType`], *optional*): + The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If + unset, no modification is done. + prepend_batch_axis (`int`, *optional*, defaults to `False`): + Whether or not to add the batch dimension during the conversion. + """ + # Convert to TensorType + if not isinstance(tensor_type, TensorType): + tensor_type = TensorType(tensor_type) + + # Get a function reference for the correct framework + if tensor_type == TensorType.TENSORFLOW: + if not is_tf_available(): + raise ImportError( + "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." + ) + import tensorflow as tf + + as_tensor = tf.constant + is_tensor = tf.is_tensor + elif tensor_type == TensorType.PYTORCH: + if not is_torch_available(): + raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") + import torch + + as_tensor = torch.tensor + is_tensor = torch.is_tensor + elif tensor_type == TensorType.JAX: + if not is_flax_available(): + raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") + import jax.numpy as jnp # noqa: F811 + + as_tensor = jnp.array + is_tensor = _is_jax + else: + as_tensor = np.asarray + is_tensor = _is_numpy + + # Do the tensor conversion in batch + + try: + if prepend_batch_axis: + inputs = [inputs] + + if not is_tensor(inputs): + inputs = as_tensor(inputs) + except: # noqa E722 + raise ValueError( + "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=True' 'truncation=True' to have batched tensors with the same length." + ) + + return inputs + + def __call__(self, artist, genres, lyrics="", return_tensors="pt") -> BatchEncoding: + """Convert the raw string to a list of token ids + + Args: + artist (`str`): + Name of the artist. + genres (`str`): + List of genres that will be mixed to condition the audio + lyrics (`str`, *optional*, defaults to `""`): + Lyrics used to condition the generation + """ + input_ids = [0, 0, 0] + artist = [artist] * len(self.version) + genres = [genres] * len(self.version) + + artists_tokens, genres_tokens, lyrics_tokens = self.tokenize(artist, genres, lyrics) + artists_id, genres_ids, full_tokens = self._convert_token_to_id(artists_tokens, genres_tokens, lyrics_tokens) + + attention_masks = [-INFINITY] * len(full_tokens[-1]) + input_ids = [ + self.convert_to_tensors( + [input_ids + [artists_id[i]] + genres_ids[i] + full_tokens[i]], tensor_type=return_tensors + ) + for i in range(len(self.version)) + ] + return BatchEncoding({"input_ids": input_ids, "attention_masks": attention_masks}) + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Saves the tokenizer's vocabulary dictionary to the provided save_directory. + + Args: + save_directory (`str`): + A path to the directory where to saved. It will be created if it doesn't exist. + + filename_prefix (`Optional[str]`, *optional*): + A prefix to add to the names of the files saved by the tokenizer. + + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + + artists_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["artists_file"] + ) + with open(artists_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.artists_encoder, ensure_ascii=False)) + + genres_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["genres_file"] + ) + with open(genres_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.genres_encoder, ensure_ascii=False)) + + lyrics_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["lyrics_file"] + ) + with open(lyrics_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.lyrics_encoder, ensure_ascii=False)) + + return (artists_file, genres_file, lyrics_file) + + def _convert_id_to_token(self, artists_index, genres_index, lyric_index): + """ + Converts an index (integer) in a token (str) using the vocab. + + Args: + artists_index (`int`): + Index of the artist in its corresponding dictionary. + genres_index (`Union[List[int], int]`): + Index of the genre in its corresponding dictionary. + lyric_index (`List[int]`): + List of character indices, which each correspond to a character. + """ + artist = self.artists_decoder.get(artists_index) + genres = [self.genres_decoder.get(genre) for genre in genres_index] + lyrics = [self.lyrics_decoder.get(character) for character in lyric_index] + return artist, genres, lyrics + + +__all__ = ["JukeboxTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53ec5ed37c13614266d04cde838ff7360946a451 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mctct import * + from .feature_extraction_mctct import * + from .modeling_mctct import * + from .processing_mctct import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py new file mode 100644 index 0000000000000000000000000000000000000000..9cba190a0f460e3fed5a3ebbc773e9ab31283c1a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/configuration_mctct.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""M-CTC-T model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class MCTCTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MCTCTModel`]. It is used to instantiate an + M-CTC-T model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the M-CTC-T + [speechbrain/m-ctc-t-large](https://huggingface.co/speechbrain/m-ctc-t-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 8065): + Vocabulary size of the M-CTC-T model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MCTCTModel`]. + hidden_size (`int`, *optional*, defaults to 1536): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 36): + Number of hidden layers in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 6144): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + attention_head_dim (`int`, *optional*, defaults to 384): + Dimensions of each attention head for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 920): + The maximum sequence length that this model might ever be used with (after log-mel spectrogram extraction). + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + layerdrop (`float`, *optional*, defaults to 0.3): + The probability of dropping an encoder layer during training. The default 0.3 value is used in the original + implementation. + hidden_act (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + hidden_dropout_prob (`float`, *optional*, defaults to 0.3): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.3): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The tokenizer index of the pad token. + bos_token_id (`int`, *optional*, defaults to 0): + The tokenizer index of the bos token. + eos_token_id (`int`, *optional*, defaults to 2): + The tokenizer index of the eos token. + conv_glu_dim (`int`, *optional*, defaults to 1): + The dimension of the output of the `Conv1dSubsampler` layer in which GLU is applied on. Though the original + Flashlight code uses the value of 2, here it's adapted to 1 due to transposition differences. + conv_dropout (`int`, *optional*, defaults to 0.3): + The probability of randomly dropping the `Conv1dSubsampler` layer during training. + num_conv_layers (`int`, *optional*, defaults to 1): + Number of convolution layers before applying transformer encoder layers. + conv_kernel (`Sequence[int]`, *optional*, defaults to `(7,)`): + The kernel size of the 1D convolution applied before transformer layers. `len(conv_kernel)` must be equal + to `num_conv_layers`. + conv_stride (`Sequence[int]`, *optional*, defaults to `(3,)`): + The stride length of the 1D convolution applied before transformer layers. `len(conv_stride)` must be equal + to `num_conv_layers`. + input_feat_per_channel (`int`, *optional*, defaults to 80): + Feature dimensions of the channels of the input to the Conv1D layer. + input_channels (`int`, *optional*, defaults to 1): + Number of input channels of the input to the Conv1D layer. + conv_channels (`List[int]`, *optional*): + Channel sizes of intermediate Conv1D layers. + ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`): + Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an + instance of [`MCTCTForCTC`]. + ctc_zero_infinity (`bool`, *optional*, defaults to `False`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly + occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance + of [`MCTCTForCTC`]. + + Example: + + ```python + >>> from transformers import MCTCTConfig, MCTCTModel + + >>> # Initializing a M-CTC-T mctct-large style configuration + >>> configuration = MCTCTConfig() + + >>> # Initializing a model (with random weights) from the mctct-large style configuration + >>> model = MCTCTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mctct" + + def __init__( + self, + vocab_size=8065, + hidden_size=1536, + num_hidden_layers=36, + intermediate_size=6144, + num_attention_heads=4, + attention_head_dim=384, + max_position_embeddings=920, + layer_norm_eps=1e-5, + layerdrop=0.3, + hidden_act="relu", + initializer_range=0.02, + hidden_dropout_prob=0.3, + attention_probs_dropout_prob=0.3, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + conv_glu_dim=1, + conv_dropout=0.3, + num_conv_layers=1, + conv_kernel=(7,), + conv_stride=(3,), + input_feat_per_channel=80, + input_channels=1, + conv_channels=None, + ctc_loss_reduction="sum", + ctc_zero_infinity=False, + **kwargs, + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.layerdrop = layerdrop + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.conv_glu_dim = conv_glu_dim + self.conv_dropout = conv_dropout + self.num_conv_layers = num_conv_layers + self.input_feat_per_channel = input_feat_per_channel + self.input_channels = input_channels + self.conv_channels = conv_channels + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # prevents config testing fail with exporting to json + self.conv_kernel = list(conv_kernel) + self.conv_stride = list(conv_stride) + + if len(self.conv_kernel) != self.num_conv_layers: + raise ValueError( + "Configuration for convolutional module is incorrect. " + "It is required that `len(config.conv_kernel)` == `config.num_conv_layers` " + f"but is `len(config.conv_kernel) = {len(self.conv_kernel)}`, " + f"`config.num_conv_layers = {self.num_conv_layers}`." + ) + + +__all__ = ["MCTCTConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py new file mode 100644 index 0000000000000000000000000000000000000000..f210031f3097e084c288fa58773d4271301e7c29 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/feature_extraction_mctct.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Feature extractor class for M-CTC-T +""" + +from typing import List, Optional, Union + +import numpy as np + +from ....audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function +from ....feature_extraction_sequence_utils import SequenceFeatureExtractor +from ....feature_extraction_utils import BatchFeature +from ....file_utils import PaddingStrategy, TensorType +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class MCTCTFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a M-CTC-T feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. This + code has been adapted from Flashlight's C++ code. For more information about the implementation, one can refer to + this [notebook](https://colab.research.google.com/drive/1GLtINkkhzms-IsdcGy_-tVCkv0qNF-Gt#scrollTo=pMCRGMmUC_an) + that takes the user step-by-step in the implementation. + + Args: + feature_size (`int`, defaults to 80): + The feature dimension of the extracted features. This is the number of mel_frequency + sampling_rate (`int`, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, defaults to 0.0): + The value that is used to fill the padding values. + hop_length (`int`, defaults to 10): + Number of audio samples between windows. Otherwise referred to as "shift" in many papers. + win_length (`int`, defaults to 25): + Number of ms per window + win_function (`str`, defaults to `"hamming_window"`): + Name for the window function used for windowing, must be accessible via `torch.{win_function}` + frame_signal_scale (`float`, defaults to 32768.0): + Constant multiplied in creating the frames before applying DFT. + preemphasis_coeff (`float`, defaults to 0.97): + Constant multiplied in applying Pre-emphasis before DFT. + mel_floor (`float` defaults to 1.0): + Minimum value of mel frequency banks. + normalize_means (`bool`, *optional*, defaults to `True`): + Whether or not to zero-mean normalize the extracted features. + normalize_vars (`bool`, *optional*, defaults to `True`): + Whether or not to unit-variance normalize the extracted features. + """ + + model_input_names = ["input_features", "attention_mask"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + padding_value=0.0, + hop_length=10, + win_length=25, + win_function="hamming_window", + frame_signal_scale=32768.0, + preemphasis_coeff=0.97, + mel_floor=1.0, + normalize_means=True, + normalize_vars=True, + return_attention_mask=False, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + + self.feature_size = feature_size + self.sampling_rate = sampling_rate + self.padding_value = padding_value + self.hop_length = hop_length + self.win_length = win_length + self.frame_signal_scale = frame_signal_scale + self.preemphasis_coeff = preemphasis_coeff + self.mel_floor = mel_floor + self.normalize_means = normalize_means + self.normalize_vars = normalize_vars + self.win_function = win_function + self.return_attention_mask = return_attention_mask + + self.sample_size = win_length * sampling_rate // 1000 + self.sample_stride = hop_length * sampling_rate // 1000 + + self.n_fft = optimal_fft_length(self.sample_size) + self.n_freqs = (self.n_fft // 2) + 1 + + def _extract_mfsc_features(self, one_waveform: np.array) -> np.ndarray: + """ + Extracts MFSC Features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC code. + """ + if self.win_function == "hamming_window": + window = window_function(window_length=self.sample_size, name=self.win_function, periodic=False) + else: + window = window_function(window_length=self.sample_size, name=self.win_function) + + fbanks = mel_filter_bank( + num_frequency_bins=self.n_freqs, + num_mel_filters=self.feature_size, + min_frequency=0.0, + max_frequency=self.sampling_rate / 2.0, + sampling_rate=self.sampling_rate, + ) + + msfc_features = spectrogram( + one_waveform * self.frame_signal_scale, + window=window, + frame_length=self.sample_size, + hop_length=self.sample_stride, + fft_length=self.n_fft, + center=False, + preemphasis=self.preemphasis_coeff, + mel_filters=fbanks, + mel_floor=self.mel_floor, + log_mel="log", + ) + return msfc_features.T + + def _normalize_one(self, x, input_length, padding_value): + # make sure we normalize float32 arrays + if self.normalize_means: + mean = x[:input_length].mean(axis=0) + x = np.subtract(x, mean) + if self.normalize_vars: + std = x[:input_length].std(axis=0) + x = np.divide(x, std) + + if input_length < x.shape[0]: + x[input_length:] = padding_value + + # make sure array is in float32 + x = x.astype(np.float32) + + return x + + def normalize( + self, input_features: List[np.ndarray], attention_mask: Optional[np.ndarray] = None + ) -> List[np.ndarray]: + lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features] + return [self._normalize_one(x, n, self.padding_value) for x, n in zip(input_features, lengths)] + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). sequences. It returns the + log-mel spectrogram of the input audio, as implemented in the original Flashlight MFSC feature extraction code. + + Args: + raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list + of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be + mono channel audio, not stereo, i.e. single float per timestep. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. + + [What are attention masks?](../glossary#attention-mask) + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + padding_value (`float`, defaults to 0.0): + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + "It is strongly recommended to pass the ``sampling_rate`` argument to this function. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # extract fbank features + features = [self._extract_mfsc_features(one_waveform) for one_waveform in raw_speech] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_features": features}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=True, + **kwargs, + ) + # make sure list is in array format + input_features = padded_inputs.get("input_features") + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] + + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + if self.normalize_means or self.normalize_vars: + attention_mask = ( + np.array(attention_mask, dtype=np.int32) + if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD + and padding + else None + ) + padded_inputs["input_features"] = self.normalize( + padded_inputs["input_features"], attention_mask=attention_mask + ) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["MCTCTFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd074b28c53d6d501ec680013afc517118bc227 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/modeling_mctct.py @@ -0,0 +1,791 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch M-CTC-T model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ....activations import ACT2FN +from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ....integrations.deepspeed import is_deepspeed_zero3_enabled +from ....integrations.fsdp import is_fsdp_managed_module +from ....modeling_attn_mask_utils import _prepare_4d_attention_mask +from ....modeling_outputs import BaseModelOutput, CausalLMOutput +from ....modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from ....utils import logging +from .configuration_mctct import MCTCTConfig + + +logger = logging.get_logger(__name__) + +_HIDDEN_STATES_START_POSITION = 1 + +_CONFIG_FOR_DOC = "MCTCTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "speechbrain/m-ctc-t-large" +_EXPECTED_OUTPUT_SHAPE = [1, 195, 1536] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = '"Mr. Quilter is the apostle of the middle classes, and we\'re glad to welcome his gospel."' +_CTC_EXPECTED_LOSS = 1885.65 + + +class MCTCTConv1dSubsampler(nn.Module): + """ + Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation + via gated linear units (https://arxiv.org/abs/1911.08460) + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.glu_dim = config.conv_glu_dim + + self.dropout = nn.Dropout(config.conv_dropout) + + self.num_layers = config.num_conv_layers + self.in_channels = config.input_feat_per_channel * config.input_channels + + if self.num_layers > 1: + if config.conv_channels is None: + raise ValueError( + "Need to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution" + " layers." + ) + + self.mid_channels = config.conv_channels + else: + self.mid_channels = None + + self.out_channels = config.hidden_size * 2 # considering GLU halving + self.kernel_size = config.conv_kernel + self.stride = config.conv_stride + + # NOTE: MCTCT by construction only uses one convolution kernel. I've made this flexible to allow for + # multiple layers of convolutions, but not sure if this model definition should just restrict it + # to one layer. This becomes especially relevant when considering the padding like line 1 of forward(). + self.conv_layers = nn.ModuleList( + nn.Conv1d( + self.in_channels if i == 0 else self.mid_channels[i], + self.mid_channels[i] if i < self.num_layers - 1 else self.out_channels, + kernel_size=k, + stride=self.stride[i], + padding="valid", + ) + for i, k in enumerate(self.kernel_size) + ) + + def forward(self, input_features): + # NOTE: in reference to the NOTE in __init__, right now it just calculates padding as if + # there will be just one conv layer. + padding = sum([size // 2 for size in self.kernel_size]) # (7, 7) -> (3, 3) + + input_features = torch.nn.functional.pad(input_features, (0, 0, padding, padding), "constant", 0) + hidden_states = input_features.transpose(1, 2).contiguous() # -> Batch x Frame x Time + for conv in self.conv_layers: + hidden_states = conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=self.glu_dim) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states.transpose(1, 2).contiguous() # -> Batch x Time x Frame + return hidden_states + + +class MCTCTEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.LayerNorm = MCTCTLayerNorm() + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", + torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device), + persistent=False, + ) + + def forward( + self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + input_shape = input_features.size() if input_features is not None else inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_features) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class MCTCTSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.attention_head_dim + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def reshape_fortran(self, x, shape): + if len(x.shape) > 0: + x = x.permute(*reversed(range(len(x.shape)))) + return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape)))) + + def relative_position_embedding_rotate(self, scores): + # NOTE: should re-evaluate whether this re-implementation was truly necessary + # or the reason why my complete re-haul worked was due to some other part + # of the code. Adding this and the reshape fortrain code seems very undesirable. + scores = scores.permute(0, 2, 3, 1) # e.g. [10, 1839, 14, 4] + + batch, hidden_state, seq_len, heads = scores.shape + + # e.g. [10, 1853, 14, 4] + scores = torch.cat((scores, torch.zeros((batch, seq_len, seq_len, heads), device=scores.device)), dim=1) + + # e.g. [10, 25942, 1, 4] + scores = self.reshape_fortran(scores, [batch, (hidden_state + seq_len) * seq_len, 1, heads]) + + # e.g. [10, 25928, 1, 4] + scores = scores[:, : (seq_len + hidden_state - 1) * seq_len] + + # e.g. [10, 1852, 14, 4] + scores = self.reshape_fortran(scores, [batch, hidden_state + seq_len - 1, seq_len, heads]) + + halfpoint = hidden_state // 2 + scores = scores[:, halfpoint : halfpoint + seq_len].transpose(1, 2) # e.g. [10, 14, 14, 4] + + return scores.permute(0, 3, 1, 2) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + mixed_query_layer = mixed_query_layer / math.sqrt(self.attention_head_size) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + # relative key position embeddings + positional_embedding = self.distance_embedding.weight + relative_position_scores = torch.einsum("lh, bche -> bcle", positional_embedding, query_layer.transpose(2, 3)) + + relative_position_scores = self.relative_position_embedding_rotate(relative_position_scores) + attention_scores = attention_scores + relative_position_scores + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in MCTCTModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).flatten(start_dim=-2) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class MCTCTLayerNorm(nn.Module): + def __init__(self): + super().__init__() + self.singleton_weight = nn.Parameter(torch.ones(1)) + self.singleton_bias = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_states): + return (hidden_states * self.singleton_weight) + self.singleton_bias + + +class MCTCTSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class MCTCTAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = MCTCTSelfAttention(config) + self.output = MCTCTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + +class MCTCTIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class MCTCTOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class MCTCTLayer(nn.Module): + def __init__(self, config: MCTCTConfig): + super().__init__() + + self.seq_len_dim = 1 + self.chunk_size_feed_forward = config.chunk_size_feed_forward + + self.intermediate = MCTCTIntermediate(config) + self.attention = MCTCTAttention(config) + self.is_decoder = config.is_decoder + self.output = MCTCTOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + self_attention_outputs = self.attention( + hidden_states, attention_mask, head_mask, output_attentions=output_attentions + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class MCTCTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MCTCTConfig + base_model_prefix = "mctct" + main_input_name = "input_features" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + std = self.config.initializer_range + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, MCTCTLayerNorm): + module.singleton_weight.data.fill_(1.0) + module.singleton_bias.data.zero_() + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers + """ + dilation = 1 + for _, kernel_sz, stride in zip( + range(self.config.num_conv_layers), self.config.conv_kernel, self.config.conv_stride + ): + padding = kernel_sz // 2 + input_lengths = input_lengths + 2 * padding - dilation * (kernel_sz - 1) - 1 + input_lengths = torch.div(input_lengths, stride, rounding_mode="trunc") + 1 + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length, attention_mask): + # generate creates 3D attention mask, because of the shape of input_features + # convert it to 2D if thats the case + if len(attention_mask.shape) > 2: + attention_mask = attention_mask[:, :, -1] + + # subsampled_lengths = attention_mask.sum(-1) + subsampled_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + bsz = attention_mask.size()[0] + attention_mask = torch.zeros( + (bsz, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + attention_mask[(torch.arange(bsz, device=attention_mask.device), subsampled_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() + return attention_mask + + +MCTCT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MCTCT_INPUTS_DOCSTRING = r""" + Args: + input_features (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +class MCTCTEncoder(MCTCTPreTrainedModel): + def __init__(self, config: MCTCTConfig): + super().__init__(config) + self.hidden_dropout_prob = config.hidden_dropout_prob + + self.layer_norm = MCTCTLayerNorm() + self.conv = MCTCTConv1dSubsampler(config) + self.layers = nn.ModuleList([MCTCTLayer(config) for _ in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + + def forward( + self, + input_features: torch.Tensor, + attention_mask: torch.Tensor, + head_mask: torch.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + input_features = self.layer_norm(input_features) + + inputs_embeds = self.conv(input_features) + + # subsample attention mask if necessary + if attention_mask is not None: + attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask) + + hidden_states = nn.functional.dropout(inputs_embeds, p=self.hidden_dropout_prob, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != len(self.layers): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, " + f"but it is for {head_mask.size()[0]}." + ) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +@add_start_docstrings( + "The bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.", + MCTCT_START_DOCSTRING, +) +class MCTCTModel(MCTCTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.encoder = MCTCTEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_features: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_features is None: + raise ValueError("You have to specify input_features.") + + encoder_outputs = self.encoder( + input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """MCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + MCTCT_START_DOCSTRING, +) +class MCTCTForCTC(MCTCTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.mctct = MCTCTModel(config) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = config.hidden_size + + self.ctc_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MCTCT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + def forward( + self, + input_features: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + if labels is not None and labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.mctct( + input_features, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.ctc_head(hidden_states) + + loss = None + if labels is not None: + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask + if attention_mask is not None + else torch.ones(input_features.shape[:-1], dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +__all__ = ["MCTCTForCTC", "MCTCTModel", "MCTCTPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py b/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py new file mode 100644 index 0000000000000000000000000000000000000000..f953c5895a0d2e1cf3a343fbdefc99861da2dadd --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mctct/processing_mctct.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Speech processor class for M-CTC-T +""" + +import warnings +from contextlib import contextmanager + +from ....processing_utils import ProcessorMixin + + +class MCTCTProcessor(ProcessorMixin): + r""" + Constructs a MCTCT processor which wraps a MCTCT feature extractor and a MCTCT tokenizer into a single processor. + + [`MCTCTProcessor`] offers all the functionalities of [`MCTCTFeatureExtractor`] and [`AutoTokenizer`]. See the + [`~MCTCTProcessor.__call__`] and [`~MCTCTProcessor.decode`] for more information. + + Args: + feature_extractor (`MCTCTFeatureExtractor`): + An instance of [`MCTCTFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`AutoTokenizer`): + An instance of [`AutoTokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "MCTCTFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def __call__(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's + [`~MCTCTFeatureExtractor.__call__`] and returns its output. If used in the context + [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's + [`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + if "raw_speech" in kwargs: + warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.") + audio = kwargs.pop("raw_speech") + else: + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") + + if audio is not None: + inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def pad(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's + [`~MCTCTFeatureExtractor.pad`] and returns its output. If used in the context + [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's + [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor.pad(*args, **kwargs) + + input_features = kwargs.pop("input_features", None) + labels = kwargs.pop("labels", None) + if len(args) > 0: + input_features = args[0] + args = args[1:] + + if input_features is not None: + input_features = self.feature_extractor.pad(input_features, *args, **kwargs) + if labels is not None: + labels = self.tokenizer.pad(labels, **kwargs) + + if labels is None: + return input_features + elif input_features is None: + return labels + else: + input_features["labels"] = labels["input_ids"] + return input_features + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the + docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your audio inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + +__all__ = ["MCTCTProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cff2c19505f9306c28edb5bdabb66f668363f5fa --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mega import * + from .modeling_mega import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9d53d52079f0dde5237b50ce36730685a3767e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/configuration_mega.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2023 The Mega Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MEGA configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ....configuration_utils import PretrainedConfig +from ....onnx import OnnxConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class MegaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MegaModel`]. It is used to instantiate a Mega + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Mega + [mnaylor/mega-base-wikitext](https://huggingface.co/mnaylor/mega-base-wikitext) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the Mega model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MegaModel`]. + hidden_size (`int`, *optional*, defaults to 128): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 4): + Number of hidden layers in the Mega encoder. + intermediate_size (`int`, *optional*, defaults to 256): + Dimensionality of the hidden size (self-attention value projection) within the Mega encoder + ema_projection_size (`int`, *optional*, defaults to 16): + Dimensionality of the MegaMultiDimensionDampedEma + bidirectional (`bool`, *optional*, defaults to `True`): + Whether the MegaMultiDimensionDampedEma used in Mega's self-attention should work bidirectionally (`True`) + or unidirectionally (`False`). Bidirectional EMA is incompatible with causal decoding, so this should be + False if you intend to use the model as a decoder. + shared_representation_size (`int`, *optional*, defaults to 64): + Dimensionality of the linear projection for shared representation of self-attention queries and keys + use_chunking (`bool`, *optional*, defaults to `False`): + Whether to chunk inputs for linear self-attention complexity (described as Mega-chunk in the paper) + chunk_size (`int`, *optional*, defaults to -1): + If `use_chunking` is set to `True`, determines the size of the chunks to apply to the input sequence. If + chunking is used, input sequences must be padded to a multiple of `chunk_size` + truncation (`int`, *optional*): + If specified, the sequence length for which to truncate MegaMultiDimensionDampedEma + normalize_before_mega (`bool`, *optional*, defaults to `True`): + Whether to normalize before (`True`) or after (`False`) passing through Mega encoder blocks + normalization_type (`str`, *optional*, defaults to `"scalenorm"`): + Type of normalization to use in Mega encoder blocks. Choose one of `"scalenorm"`, `"layernorm"`, + `"rmsnorm"`, `"batchnorm"`, or `"syncbatchnorm"` (GPU required for syncbatchnorm) + norm_affine (`bool`, *optional*, defaults to `True`): + If `True`, applies a parameterized affine transformation to inputs during normalization + activation (`str`, *optional*, defaults to `"silu"`): + Activation function to apply within Mega encoder blocks. Choose one of `"silu"`, `"relu"`, `"linear"`, + `"gelu"`, or `"gelu_accurate"` + attention_activation (`str`, *optional*, defaults to `"softmax"`): + Activation function to apply for single-headed self-attention (a la Transformer). Choose one of + `"softmax"`, `"laplace"`, or `"relu2"` + dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for EMA self-attention + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + use_feature_dropout (`bool`, *optional*, defaults to `False`): + Whether to use feature-based (`True`) or standard dropout (`False`) + use_normalized_ffn (`bool`, *optional*, defaults to `True`): + Whether to use the normalized feed-forward sub-layer in Mega blocks (`True`) or pass Mega encoder output + as-is (`False`) + nffn_hidden_size (`int`, *optional*, defaults to 256): + If using the normalized feed-forward network (NFFN) layer within Mega (`use_normalized_ffn = True`), this + is the hidden size of the NFFN + normalize_before_ffn (`bool`, *optional*, defaults to `True`): + Whether to normalize before (`True`) or after (`False`) the feed-forward portion of NFFN + nffn_activation_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the NFFN component. + max_positions (`int`, *optional*, defaults to 2048): + The maximum sequence length to use for positional representations. For `"simple"` relative positional bias, + this is a hard limit on input length; `"rotary"` relative positional bias will extrapolate to longer + sequences + add_token_type_embeddings (`bool`, *optional*, defaults to `True`): + Whether to account for token types in embeddings. Left as optional to maintain compatibility with original + implementation while adding support for token types. + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`MegaModel`]. Only used if + `add_token_type_embeddings = True` + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + ema_delta_alpha_range (`float`, *optional*, defaults to 0.2): + The standard deviation for initializing the delta (damping factor) and alpha (decay factor) parameters in + MegaMultiDimensionDampedEma. + ema_beta_range (`float`, *optional*, defaults to 0.02): + The standard deviation for initializing the beta parameter (expansion matrix) in + MegaMultiDimensionDampedEma. + ema_gamma_omega_range (`float`, *optional*, defaults to 1.0): + The standard deviation for initializing the gamma (projection matrix) and omega (residual weight) + parameters in MultiDimensionEMA. + relative_positional_bias (`str`, *optional*, defaults to `"rotary"`): + Type of relative positional encoding. Choose one of `"rotary"` or `"simple"`. If `"simple"` is selected, + `max_positions` is used as a limit on input size, while `"rotary"` extrapolates beyond `max_positions`. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + add_lm_hidden_dense_layer (`bool`, *optional*, defaults to `True`): + Whether to include a hidden layer for projection between encoder outputs and LM heads (`True`) or pass + hidden states directly to LM head (`False`). Remains optional for compatibility with original + implementation + + Examples: + + ```python + >>> from transformers import MegaConfig, MegaModel + + >>> # Initializing a Mega configuration + >>> configuration = MegaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MegaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mega" + + def __init__( + self, + vocab_size=30522, + hidden_size=128, + num_hidden_layers=4, + intermediate_size=256, + ema_projection_size=16, + bidirectional=True, + shared_representation_size=64, + use_chunking=False, + chunk_size=-1, + truncation=None, + normalize_before_mega=True, + normalization_type="scalenorm", + norm_affine=True, + activation="silu", + attention_activation="softmax", + dropout_prob=0.1, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + use_feature_dropout=False, + use_normalized_ffn=True, + nffn_hidden_size=256, + normalize_before_ffn=True, + nffn_activation_dropout_prob=0.1, + max_positions=2048, + add_token_type_embeddings=False, + type_vocab_size=2, + initializer_range=0.02, + ema_delta_alpha_range=0.2, + ema_beta_range=0.02, + ema_gamma_omega_range=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + relative_positional_bias="rotary", + classifier_dropout=None, + use_cache=True, + add_lm_hidden_dense_layer=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.activation = activation + self.attention_activation = attention_activation + self.intermediate_size = intermediate_size + self.ema_projection_size = ema_projection_size + self.bidirectional = bidirectional + self.shared_representation_size = shared_representation_size + self.use_chunking = use_chunking + self.chunk_size = chunk_size + self.truncation = truncation + self.normalize_before_mega = normalize_before_mega + self.normalization_type = normalization_type + self.norm_affine = norm_affine + self.dropout_prob = dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.use_feature_dropout = use_feature_dropout + self.use_normalized_ffn = use_normalized_ffn + self.nffn_hidden_size = nffn_hidden_size + self.normalize_before_ffn = normalize_before_ffn + self.nffn_activation_dropout_prob = nffn_activation_dropout_prob + self.max_positions = max_positions + self.add_token_type_embeddings = add_token_type_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.ema_delta_alpha_range = ema_delta_alpha_range + self.ema_beta_range = ema_beta_range + self.ema_gamma_omega_range = ema_gamma_omega_range + self.relative_positional_bias = relative_positional_bias + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + self.add_lm_hidden_dense_layer = add_lm_hidden_dense_layer + self.num_attention_heads = 1 # not used but required by Hugging Face + + +class MegaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["MegaConfig", "MegaOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c6dbb12890e55fd26c39794aa28018265e98e164 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Convert Mega pretrained checkpoint. Built to convert the Masked LM checkpoint located at +https://huggingface.co/mnaylor/mega-wikitext-103 + +Requirements: + - clone the Mega repo and install fairseq from there + 1. git clone https://github.com/facebookresearch/mega.git + 2. cd mega && pip install -e + - clone the pretrained weights for the original implementation from the hugging face repo + * use this location as the path for pretrained weights +""" + +import argparse + +# utilities to import the model weights and config file +import os +import pickle as pkl + +# PyTorch + new model classes +import torch +from torch import nn + +from transformers import AutoTokenizer, MegaConfig, MegaForMaskedLM + + +# import the EncoderLayer class used to pretrain +# !! NOTE !! this requires the version of fairseq that is built when you install the Mega source +try: + from fairseq.modules.mega_layer import MegaEncoderLayer +except ImportError: + raise ImportError("You need to install the version of fairseq from the Mega repo!") + + +# define the wrapper classes used to train the MLM (see colab notebook below) +# https://colab.research.google.com/drive/1qfUO6o5HRdxBblWlw058HVyvaEPhPpH8?usp=sharing +# MegaLM outputs hidden states +class MegaLM(nn.Module): + "The base class for our Mega encoder - given input IDs, embed text and return encoder output" + + def __init__(self, mega_args, depth, vocab_size): + super().__init__() + self.mega_args = mega_args + self.embedding_layer = nn.Embedding(vocab_size, self.mega_args.encoder_embed_dim) + self.encoders = nn.ModuleList([MegaEncoderLayer(self.mega_args) for _ in range(depth)]) + self.depth = depth + + def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0): + """ + Code for a forward pass - expects input_ids and attention_mask to come from a Hugging Face tokenizer as PyTorch + tensors, and returns a tensor of size (batch, n_classes) containing classification logits + + Other options: + - batch_first: boolean indicating whether the batch dimension is first in input_ids (default: True, which + aligns with the HF tokenizer behavior) + - ignore_mask_value: the value in attention_mask that identifies tokens that should be ignored (default: 0, + which aligns with HF tokenizer) + """ + + # Mega expects embeddings to be (time, batch, embedding size), but + # Hugging Face returns tokens as (batch, time) + if batch_first: + input_ids = input_ids.T + + # to make things more confusing, Mega expects the attention mask to + # be (batch, time), but with values of 0 (normal token) and 1 (ignore token) + # which is the opposite of what HF returns + if ignore_mask_value == 0: + attention_mask = 1 - attention_mask + + # get token embeddings from IDs + embeds = self.embedding_layer(input_ids) + + # pass through the Mega layers + # input is (time, batch, encoder dim) and output is the same + for encoder in self.encoders: + embeds = encoder(embeds, attention_mask) + + # return according to the shape specified + if batch_first: + # (T, B, H) --> (B, T, H) + return torch.transpose(embeds, 0, 1) + else: + return embeds + + +# renamed from MegaForMaskedLM to avoid confusion with new module +class OriginalMegaForMaskedLM(nn.Module): + "A wrapper class for doing masked language modeling with Mega" + + def __init__(self, mega_args, depth, vocab_size): + super().__init__() + self.mega = MegaLM(mega_args, depth, vocab_size) + self.mlm_head = nn.Linear(mega_args.encoder_embed_dim, vocab_size) + self.dropout = nn.Dropout(p=0.1) + + def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0): + """ + Perform a forward pass through the Mega encoder and the masked LM head. Returns logits for each vocabulary + entry. + + If `batch_first` (default to align with Hugging Face tokenizer behavior), output will have the shape (Batch + size, Sequence length, Vocab size); otherwise (S, B, V) + """ + encoder_output = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value) + return self.mlm_head(self.dropout(encoder_output)) + + +# code to convert the checkpoint located in the user-specified location +def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer): + with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f: + mega_original_args = pkl.load(f) + + # load the original encoder + original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval() + + # load its weights + print( + "Original Mega encoder:", + original_mlm.mega.load_state_dict( + torch.load( + os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True + ) + ), + ) + print( + "Original Mega MLM layer:", + original_mlm.mlm_head.load_state_dict( + torch.load( + os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True + ) + ), + ) + + # create a new config from the old one + hf_config = MegaConfig( + num_hidden_layers=mega_original_args["depth"], + vocab_size=mega_original_args["vocab_size"], + hidden_size=mega_original_args["mega_args"].encoder_embed_dim, + shared_representation_size=mega_original_args["mega_args"].encoder_z_dim, + intermediate_size=mega_original_args["mega_args"].encoder_hidden_dim, + ema_projection_size=mega_original_args["mega_args"].encoder_n_dim, + dropout_prob=mega_original_args["mega_args"].dropout, + attention_probs_dropout_prob=mega_original_args["mega_args"].attention_dropout, + hidden_dropout_prob=mega_original_args["mega_args"].hidden_dropout, + activation=mega_original_args["mega_args"].activation_fn, + attention_activation=mega_original_args["mega_args"].attention_activation_fn, + bidirectional=mega_original_args["mega_args"].bidirectional, + use_chunking=mega_original_args["mega_args"].encoder_chunk_size > 0, + chunk_size=mega_original_args["mega_args"].encoder_chunk_size, + truncation=mega_original_args["mega_args"].truncation_length, + normalization_type=mega_original_args["mega_args"].normalization_type, + normalize_before_mega=True, + norm_affine=True, + use_feature_dropout=mega_original_args["mega_args"].feature_dropout, + relative_positional_bias=mega_original_args["mega_args"].rel_pos_bias, + max_positions=mega_original_args["mega_args"].max_source_positions, + nffn_hidden_size=mega_original_args["mega_args"].encoder_ffn_embed_dim, + normalize_before_ffn=mega_original_args["mega_args"].normalize_before, + # new arguments added for HF implementation + nffn_activation_dropout_prob=0.0, + add_token_type_embeddings=False, + add_lm_hidden_dense_layer=False, + ) + + hf_mlm = MegaForMaskedLM(hf_config).eval() + + # the originl checkpoint just uses nn.Embedding for the word embeddings + # we use a wrapper module for embeddings to add support for positional embeddings + hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight + + # modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face + # ecosystem -- any names containing "beta" or "gamma" aren't safe to use and are renamed upon _load_pretrained, + # also renaming previously confusing parameter names + original_state_dict = original_mlm.mega.encoders.state_dict() + updated_keys = {} + for module_name in original_state_dict.keys(): + new_module_name = None + # have to handle gamma, beta, and alpha differently due to their use + # in multiple modules within the original repository; + # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed due to flax/tf weights + # the EMA sublayer was renamed from "move" to "ema_gate" for readability, so that is also done here + if "beta" in module_name: + # EMA sub-layers were always called "move" in the original repo + if "move.beta" in module_name: + new_module_name = module_name.replace("move.beta", "ema_gate.ema_expansion_matrix") + elif "mega_layer.beta" in module_name: + new_module_name = module_name.replace("beta", "qk_bias") + else: + new_module_name = module_name.replace("beta", "b_param") + # beta is used in EMA and MovingAverageGatedAttention, and must be renamed due to flax/tf weights + elif "gamma" in module_name: + if "move.gamma" in module_name: + new_module_name = module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix") + elif "mega_layer.gamma" in module_name: + new_module_name = module_name.replace("gamma", "qk_weight") + else: + new_module_name = module_name.replace("gamma", "g_param") + # alpha is used in EMA and positional bias; renaming to improve readability + elif "move.alpha" in module_name: + new_module_name = module_name.replace("move.alpha", "ema_gate.decay_factor") + # delta is only used in EMA; renaming to improve readability + elif "move.delta" in module_name: + new_module_name = module_name.replace("move.delta", "ema_gate.damping_factor") + # omega is only used in EMA; renaming to improve readability + elif "omega" in module_name: + new_module_name = module_name.replace("move.omega", "ema_gate.residual_weight") + + if new_module_name: + updated_keys[module_name] = new_module_name + + if len(updated_keys) != 0: + print(f"Renaming these keys: {updated_keys.keys()}") + else: + print("No need to rename state dict entries") + for old, new in updated_keys.items(): + original_state_dict[new] = original_state_dict.pop(old) + + # now attempt to load the state dictionary with updated names + # note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style + print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict)) + + # load the MLM head weights directly + print( + "HF Mega MLM layer:", + hf_mlm.mlm_head.load_state_dict( + torch.load( + os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True + ) + ), + ) + + # test on a randomly generated input sequence + input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256)) + input_mask = torch.ones_like(input_ids) + # mask a few tokens to make sure masking is applied appropriately :) + input_mask[:, -10:] = 0 + + # run forward passes + original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0) + hf_output = hf_mlm(input_ids, input_mask)[0] + + # print shapes and diff + print(f"original output {original_output.shape}") + print(f"hf output {hf_output.shape}") + print(f"max diff: {(original_output - hf_output).max()}") # 0.0 + success = torch.allclose(original_output, hf_output, atol=1e-3) + + if success: + print("Yay!") + hf_mlm.save_pretrained(output_path) + else: + raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}") + + if includes_tokenizer: + print("Transferring tokenizer") + tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path) + tokenizer.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--pretrained_checkpoint_path", + default=None, + type=str, + required=True, + help="Point to the directory containing your model weights using the official Mega repo", + ) + + parser.add_argument( + "--output_path", default=None, type=str, required=True, help="Location to save the Hugging Face version" + ) + + parser.add_argument( + "--includes_tokenizer", + action="store_true", + help="Use this flag if there is a Hugging Face tokenizer in the original checkpoint repo", + ) + + args = parser.parse_args() + + convert_checkpoint_to_huggingface(args.pretrained_checkpoint_path, args.output_path, args.includes_tokenizer) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py b/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a490b01d3cc893758490c669f3928b93edab65 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mega/modeling_mega.py @@ -0,0 +1,2285 @@ +# coding=utf-8 +# Copyright 2023 The Mega Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MEGA model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import ALL_LAYERNORM_LAYERS +from ....utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_mega import MegaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "mnaylor/mega-base-wikitext" +_CONFIG_FOR_DOC = "MegaConfig" + + +class MegaEmbeddings(nn.Module): + """ + Mega's basic implementation does not incorporate token type embeddings, so this is a stripped-down version of + RoBERTa's embeddings which optionally includes token types + """ + + def __init__(self, config: MegaConfig): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.use_token_types = config.add_token_type_embeddings + if self.use_token_types: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + # registering a buffer here allows model tracing when not passing optional token type IDs + # more info at transformers issue #5664 + self.register_buffer( + "token_type_ids", torch.zeros(config.max_positions, dtype=torch.long).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + + def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None): + if (input_ids is None) and (inputs_embeds is None): + raise ValueError("Must provide one of input_ids or inputs_embeds") + elif input_ids is not None: + input_shape = input_ids.size() + device = input_ids.device + + # get the word embeddings if only IDs are provided + inputs_embeds = self.word_embeddings(input_ids) + else: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + + # the original Mega implementation did not include token type embeddings, so we add + # an option to use them if desired; if embeddings are present and token type IDs are + # not provided, we will use a registered buffer (which helps with tracing) + if self.use_token_types: + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, : input_shape[1]] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], input_shape[1]) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # access token type embeddings + token_type_embeddings = self.token_type_embeddings(token_type_ids) + # add the token type embeddings to the word embeddings + embeddings = inputs_embeds + token_type_embeddings + else: + embeddings = inputs_embeds + return embeddings + + +class MegaSimpleRelativePositionalBias(nn.Module): + """ + Simple relative positional embeddings copied from the Mega repo; renamed variables for better readability + """ + + def __init__(self, config: MegaConfig): + super().__init__() + self.config = config + self.max_positions = self.config.max_positions if self.config.chunk_size < 0 else self.config.chunk_size + self.rel_pos_bias = nn.Parameter(torch.Tensor(2 * config.max_positions - 1)) + + def forward(self, seq_len): + if seq_len > self.max_positions: + raise ValueError("Sequence length {} going beyond max length {}".format(seq_len, self.max_positions)) + + # seq_len * 2 - 1 + bias = self.rel_pos_bias[(self.max_positions - seq_len) : (self.max_positions + seq_len - 1)] + # seq_len * 3 - 1 + tile = F.pad(bias, (0, seq_len)) + # (seq_len * 3 - 1) * seq_len + tile = torch.tile(tile, (seq_len,)) + tile = tile[:-seq_len] + # seq_len x (3 * seq_len - 2) + tile = tile.view(seq_len, 3 * seq_len - 2) + start = (2 * seq_len - 1) // 2 + end = tile.size(1) - start + tile = tile[:, start:end] + return tile + + +class MegaRotaryRelativePositionalBias(nn.Module): + """ + Rotary relative bias for positional information; similar in concept to RoPE (i.e. RoFormer) but taken from the Mega + repo due to differences in implementation. + + When initialized, produces a positional bias which ranges from position 0 to config.max_positions, but can + extrapolate to longer sequences. Can be indexed according to input position IDs + """ + + def __init__(self, config: MegaConfig): + super().__init__() + if config.hidden_size % 2 != 0: + raise RuntimeError("Rotary positional bias requires `hidden_size` to be a multiple of 2") + self.config = config + self.embed_dim = config.shared_representation_size + self.max_positions = self.config.max_positions if self.config.chunk_size < 0 else self.config.chunk_size + self.sine, self.cosine = MegaRotaryRelativePositionalBias.get_sinusoid_embeddings( + config.max_positions, self.embed_dim + ) + # alpha and beta parameters for the rotary bias; beta renamed to b_param to avoid clashes with tf/flax weight handling + # in loading pretrained weights + self.alpha = nn.Parameter(torch.Tensor(1, self.embed_dim)) + self.b_param = nn.Parameter(torch.Tensor(1, self.embed_dim)) + self.register_buffer("_float_tensor", torch.FloatTensor([0.0])) + + @staticmethod + def get_sinusoid_embeddings(max_positions: int, embedding_dim: int): + half_dim = embedding_dim // 2 + emb = math.log(10000) / half_dim + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + return torch.sin(emb), torch.cos(emb) + + def rotary(self, input): + seq_len, embed_dim = input.size() + chunk_1, chunk_2 = torch.chunk(input, 2, dim=-1) + if self.sine is None or seq_len > self.sine.size(0): + self.sine, self.cosine = MegaRotaryRelativePositionalBias.get_sinusoid_embeddings(seq_len, embed_dim) + self.max_positions = seq_len + self.sine = self.sine.to(self._float_tensor) + self.cosine = self.cosine.to(self._float_tensor) + + sin = self.sine[:seq_len] + cos = self.cosine[:seq_len] + return torch.cat([chunk_1 * cos - chunk_2 * sin, chunk_2 * cos + chunk_1 * sin], dim=1) + + def forward(self, seq_len): + rotary_alpha = self.rotary(self.alpha.expand(seq_len, self.embed_dim)) + rotary_beta = self.rotary(self.b_param.expand(seq_len, self.embed_dim)) + bias = torch.einsum("mk,nk->mn", rotary_alpha, rotary_beta) + return bias + + +class MegaDropout(nn.Module): + """ + A unified class for standard dropout functionality and featurewise dropout. + + The original fairseq Mega repo used 2 classes for these, which included some unnecessary handling of training logic + and an unused `inplace` option. The original implementation used torch.nn.functional instead of submodules, which + is retained here as well. + """ + + def __init__(self, dropout_probability, is_featurewise=False): + super().__init__() + self.dropout_probability = dropout_probability + self.is_featurewise = is_featurewise + + def forward(self, input, batch_first: bool = False): + if self.is_featurewise: + if batch_first: + # (batch_size X sequence_length X feature_dimension) + # -> (batch_size X feature_dimension X sequence_length) + # -> (batch_size X sequence_length X feature_dimension) + return F.dropout2d( + input.transpose(-1, -2), p=self.dropout_probability, training=self.training + ).transpose(-1, -2) + else: + if input.dim() != 3: + raise ValueError( + "Feature dropout inputs must be exactly 3-dimensional if inputs are ordered [sequence length, batch size, hidden dimension]" + ) + # (sequence_length X batch_size X feature_dimension) + # -> (batch_size X feature_dimension X sequence_length) + # -> (sequence_length X batch_size X feature_dimension) + return F.dropout2d(input.permute(1, 2, 0), p=self.dropout_probability, training=self.training).permute( + 2, 0, 1 + ) + else: + return F.dropout(input, p=self.dropout_probability, training=self.training) + + +class MegaRMSNorm(nn.Module): + """ + RMSNorm used in Mega implementation. Differs from T5's RMSNorm by applying the weight prior to taking the square + root (as opposed to after in T5) + """ + + def __init__(self, number_features, eps=1e-6, affine=True): + super().__init__() + self.num_features = number_features + self.eps = eps + self.affine = affine + if affine: + self.weight = nn.Parameter(torch.Tensor(self.num_features)) + else: + self.register_parameter("weight", None) + + def forward(self, input): + mean_square = torch.mean(torch.square(input), dim=-1, keepdim=True) + if self.weight is not None: + input = input * self.weight + + input * torch.rsqrt(mean_square + self.eps) + return input + + def extra_repr(self): + return f"{self.num_features}, eps={self.eps}, affine={self.affine}" + + +class MegaScaleNorm(nn.Module): + """ + Scale normalization introduced in MEGA which is similar to RMSNorm, but uses a single parameter for scalar + multiplication instead of a vector, and applies over a specified dimension + """ + + def __init__(self, dim, eps=1e-6, affine=True): + super().__init__() + self.dim = dim + self.eps = eps + self.affine = affine + if affine: + self.scalar = nn.Parameter(torch.Tensor(1)) + else: + self.register_parameter("scalar", None) + + def forward(self, input): + mean_square = torch.mean(torch.square(input), dim=self.dim, keepdim=True) + if self.scalar is not None: + input = self.scalar * input + + output = input * torch.rsqrt(mean_square + self.eps) + return output + + +class MegaSequenceNorm(nn.Module): + """ + A wrapper class for various layer normalization options used in Mega. Used to handle differences in expectations on + input axis locations for different normalization methods. + """ + + def __init__(self, norm_type, embedding_dim, eps=1e-5, affine=True, export=False): + super().__init__() + if norm_type == "layernorm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine=affine) + elif norm_type == "scalenorm": + self.norm = MegaScaleNorm(dim=-1, eps=eps, affine=affine) + elif norm_type == "rmsnorm": + self.norm = MegaRMSNorm(embedding_dim, eps=eps, affine=affine) + elif norm_type == "batchnorm": + self.norm = nn.BatchNorm1d(embedding_dim, eps=eps, affine=affine) + elif norm_type == "syncbatchnorm": + self.norm = nn.SyncBatchNorm(embedding_dim, eps=eps, affine=affine) + else: + raise ValueError("Unknown norm type: {}".format(norm_type)) + + def forward(self, input): + if isinstance(self.norm, nn.modules.batchnorm._BatchNorm): + if input.dim() != 3: + raise ValueError("BatchNorm inputs must be exactly 3-dimensional") + input = input.permute(1, 2, 0) + input = self.norm(input) + return input.permute(2, 0, 1) + else: + return self.norm(input) + + +# add this layernorm class to ALL_LAYERNORM_LAYERS +ALL_LAYERNORM_LAYERS.append(MegaSequenceNorm) + + +class MegaMultiDimensionDampedEma(nn.Module): + """ + Mega's Exponential Moving Average layer, largely left unmodified from the original repo with the exception of + variable names and moving away from the stateful representation of incremental decoding state. See + "https://arxiv.org/abs/2209.10655" for more details. + """ + + def __init__(self, config: MegaConfig): + super().__init__() + + self.config = config + + self.embed_dim = config.hidden_size + self.ndim = config.ema_projection_size + self.bidirectional = config.bidirectional + self.truncation = config.truncation + self.scale = math.sqrt(1.0 / self.ndim) + + kernel_dim = 2 * config.hidden_size if self.bidirectional else config.hidden_size + # renamed delta (damping_factor) and alpha (decay_factor) to be more descriptive of what the parameters are doing + self.damping_factor = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1)) + self.decay_factor = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1)) + # renamed gamma (kernel_projection_matrix) and beta (ema_expansion_matrix) respectively to avoid HF renaming + # things and align with the paper's description of these params' behavior + self.ema_expansion_matrix = nn.Parameter(torch.Tensor(kernel_dim, self.ndim, 1)) + self.kernel_projection_matrix = nn.Parameter(torch.Tensor(kernel_dim, self.ndim)) + # renamed omega to residual_weight to describe what it's doing + self.residual_weight = nn.Parameter(torch.Tensor(config.hidden_size)) + self._kernel = None + self._coeffs = None + + def _compute_ema_coefficients(self): + self._coeffs = None + # convert the alpha and delta parameters (kernel_dim x EMA projection size x 1) to [0, 1] with sigmoid + damping_factor = torch.sigmoid(self.damping_factor) + decay_factor = torch.sigmoid(self.decay_factor) + previous_timestep_weight = 1.0 - damping_factor * decay_factor + return damping_factor, previous_timestep_weight + + def _compute_efficient_ema_kernel(self, length: int): + # computes the kernel used for efficient damped EMA applied via FFT convolution + self._kernel = None + # p and q have shape (kernel_dim x ema_projection_size x 1) + damping_factor, previous_timestep_weight = self._compute_ema_coefficients() + # extend the kernel to (kernel_dim X ema_projection_size X sequence_length) and + # multiply q by sequential ints up to the sequence length + vander = torch.arange(length).to(damping_factor).view(1, 1, length) * torch.log(previous_timestep_weight) + kernel = (damping_factor * self.ema_expansion_matrix) * torch.exp(vander) + # (kernel_dim X ema_projection_size X sequence_length) -> (kernel_dim, sequence_length) + return torch.einsum("dnl,dn->dl", kernel, self.kernel_projection_matrix * self.scale) + + def get_ema_coefficients(self): + if self.training: + return self._compute_ema_coefficients() + else: + if self._coeffs is None: + self._coeffs = self._compute_ema_coefficients() + return self._coeffs + + def get_ema_kernel(self, length: int): + kernel_size = length if self.truncation is None else min(self.truncation, length) + if self.training: + return self._compute_efficient_ema_kernel(kernel_size) + else: + if self._kernel is None or self._kernel.size(-1) < kernel_size: + self._kernel = self._compute_efficient_ema_kernel(kernel_size) + return self._kernel[..., :kernel_size] + + def fft_convolution(self, inputs, kernel, length): + # this is a wrapper for repeated use of EMA calculation via FFT (fast Fourier transform) convolution + inputs_fft = torch.fft.rfft(inputs.float(), n=2 * length) + kernel_fft = torch.fft.rfft(kernel.float(), n=2 * length) + convolved_sequence = torch.fft.irfft(inputs_fft * kernel_fft, n=2 * length) + return convolved_sequence + + def ema_step(self, inputs, length, past_state=None): + if length == 1: + return self.one_ema_step(inputs, past_state=past_state) + + # (kernel_dim X ema_projection_size X 1) + damping_factor, previous_timestep_weight = self.get_ema_coefficients() + # (kernel_dim X ema_projection_size X 1+sequence_length) + vander = torch.arange(length + 1).to(damping_factor).view(1, 1, length + 1) * torch.log( + previous_timestep_weight + ) + vander = torch.exp(vander) + if past_state is not None: + # (kernel_dim X ema_projection_size X sequence_length) * (kernel_dim X ema_projection_size X 1) + # -> (kernel_dim X ema_projection_size X sequence_length) + past_ema_proj = vander[:, :, 1:] * (self.kernel_projection_matrix * self.scale).unsqueeze(-1) + # past_state will be (batch_size, kernel_dim, ema_projection_size) + past_ema_state = torch.einsum("bdn,dnl->bdl", past_state, past_ema_proj) + # (kernel_dim X ema_projection_size) * (batch_size X kernel_dim X ema_projection_size) + # -> (batch_size X kernel_dim X ema_projection_size) + past_vandermonde = vander[:, :, -1] * past_state + else: + past_ema_state = None + past_vandermonde = None + + # (kernel_dim X ema_projection_size X sequence_length) + vander = vander[:, :, :-1] + kernel = (damping_factor * self.ema_expansion_matrix) * vander + kernel_proj = torch.einsum("dnl,dn->dl", kernel, self.kernel_projection_matrix * self.scale) + + ema_output = self.fft_convolution(inputs, kernel_proj, length=length)[..., 0:length] + ema_output = ema_output.type_as(inputs) + if past_ema_state is not None: + ema_output = ema_output + past_ema_state + + updated_hidden_state = torch.einsum("bdl,dnl->bdn", inputs, torch.flip(kernel, dims=[2])) + if past_vandermonde is not None: + updated_hidden_state = updated_hidden_state + past_vandermonde + # return a tuple: + # (sequence_length, batch_size, kernel_dim) + # (batch_size, kernel_dim, ema_projection_size) + return ema_output.permute(2, 0, 1), updated_hidden_state + + def one_ema_step(self, inputs, past_state=None): + damping_factor, previous_timestep_weight = self.get_ema_coefficients() + # (kernel_dim X ema_projection_size) x (batch_size X kernel_dim X 1) + # -> (batch_size X kernel_dim X ema_projection_size) + updated_state = (damping_factor * self.ema_expansion_matrix).squeeze(-1) * inputs + if past_state is not None: + updated_state = updated_state + previous_timestep_weight.squeeze(-1) * past_state + # (batch_size X kernel_dim) + out = torch.einsum("bdn,dn->bd", updated_state, self.kernel_projection_matrix * self.scale) + # (1 X batch_size X kernel_dim), (batch_size X kernel_dim X ema_projection_size) + return out.unsqueeze(0), updated_state + + def forward( + self, + inputs, + attention_mask: Optional[torch.Tensor] = None, + prev_state: Optional[torch.Tensor] = None, + use_cache: bool = False, + ) -> torch.Tensor: + """ + Mega's exponential moving average (EMA) sub-layer applied prior to single-headed (traditional) self-attention + + Args: + inputs (`torch.Tensor` of shape `(sequence_length, batch_size, hidden_size)`): + Hidden state / embedding input to update via EMA based on FFT convolution + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored (mostly due to padding), where elements are either 1 for *not + masked* or 0 for *masked* + prev_state (`torch.Tensor` of shape `(batch_size, config.ndim)`, *optional*): + The hidden state returned from the previous timestep during incremental decoding. + use_cache (`bool`, default `False`): + Whether to perfom incremental decoding; uses `prev_state` as the prior timestep, and returns the + updated EMA hidden state for use in the next step + + Returns: + `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and + inputs: + - **hidden_states** (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`) -- Hidden + states updated by EMA, with same shapes as inputs + - **updated_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor of shape `(batch_size, + config.ndim)` -- The incremental EMA state for use in the next step of incremental decoding + """ + + seq_len, bsz, embed_dim = inputs.size() + if embed_dim != self.embed_dim: + raise ValueError( + f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}" + ) + + # sequence_length X batch_size X hidden_size + residual = inputs * self.residual_weight + + # (sequence_length x batch_size x hidden_size) -> (batch_size x hidden_size x sequence_length) + inputs = inputs.permute(1, 2, 0) + # mask the input: output is a tensor with 0 in the masked positions + if attention_mask is not None: + inputs = inputs * (attention_mask.unsqueeze(1).type_as(inputs)) + + if self.bidirectional and use_cache: + raise RuntimeError("Bidirectional EMA does not support incremental state") + + if use_cache: + out, updated_state = self.ema_step(inputs, seq_len, past_state=prev_state) + + # (batch_size X hidden_size) -> (1 x batch_size x hidden_size) + out = F.silu(out + residual) + + # if incremental decoding, return the new state along with the output + return out, updated_state + else: + # (hidden_size x sequence_length) + kernel = self.get_ema_kernel(seq_len) + fft_len = seq_len + s_index = 0 + kernel_size = kernel.size(1) + if self.bidirectional: + # split the kernel for each direction of EMA + k1, k2 = torch.split(kernel, [self.embed_dim, self.embed_dim], dim=0) + # (hidden_size X 2*sequence_length - 1) + kernel = F.pad(k1, (kernel_size - 1, 0)) + F.pad(k2.flip(-1), (0, kernel_size - 1)) + inputs = F.pad(inputs, (kernel_size - 1, 0)) + fft_len = fft_len + kernel_size - 1 + s_index = 2 * kernel_size - 2 + + ema_output = self.fft_convolution(inputs, kernel, length=fft_len)[..., s_index : s_index + seq_len] + ema_output = ema_output.type_as(inputs) + # (batch_size X hidden_size X sequence_length) -> (sequence_length X batch_size X hidden_size) + gated_ema_output = F.silu(ema_output.permute(2, 0, 1) + residual) + + return gated_ema_output, None + + +class MegaGatedCrossAttention(nn.Module): + """ + Gated Structured State Attention for use in encoder-decoder model. See Mega paper for more details. Only + modifications from original implementation are variable names, removing the unnecessary `before_attn_fn` and + `static_kv` arguments, and the stateful representation of incremental decoder state. + """ + + def __init__(self, config: MegaConfig): + super().__init__() + + self.config = config + self.activation = ACT2FN[self.config.activation] + self.attention_activation = self.config.attention_activation + self.scaling = self.config.shared_representation_size**-0.5 if self.attention_activation == "softmax" else None + + self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=self.config.use_feature_dropout) + self.hidden_dropout = MegaDropout( + self.config.hidden_dropout_prob, is_featurewise=self.config.use_feature_dropout + ) + # Attention dropout is standard dropout + self.attention_dropout = MegaDropout(self.config.attention_probs_dropout_prob, is_featurewise=False) + + self.prenorm = self.config.normalize_before_mega + self.norm = MegaSequenceNorm( + self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine + ) + + self.k_proj = nn.Linear(self.config.hidden_size, self.config.shared_representation_size) + self.v_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size) + self.q_proj = nn.Linear( + self.config.hidden_size, 2 * self.config.hidden_size + self.config.shared_representation_size + ) + self.h_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size) + + if self.config.relative_positional_bias == "simple": + self.rel_pos_bias = MegaSimpleRelativePositionalBias(config) + elif self.config.relative_positional_bias == "rotary": + self.rel_pos_bias = MegaRotaryRelativePositionalBias(config) + else: + raise ValueError("unknown relative position bias: {}".format(self.config.relative_positional_bias)) + + self.softmax = nn.Softmax(dim=-1) + + def element_attention(self, query, key, key_padding_mask, pidx): + bsz, src_len, _ = key.size() + tgt_len = query.size(1) if pidx is None else pidx + 1 + if key_padding_mask is not None: + # (batch_size X source_sequence_length) --> (batch_size X 1 X 1) + lengths = key_padding_mask.sum(dim=-1).view(bsz, 1, 1) + else: + lengths = src_len + + # (target_sequence_length X source_sequence_length) + bias = self.rel_pos_bias(max(tgt_len, src_len))[:, :src_len] + if pidx is not None: + if query.size(1) != 1: + raise ValueError("Position offset provided with queries longer than 1 token") + # source_sequence_length + bias = bias[pidx] + else: + # (target_sequence_length X source_sequence_length) + bias = bias[:tgt_len] + + # (batch_size X target_sequence_length X source_sequence_length) + qk = torch.bmm(query, key.transpose(1, 2)) / lengths + bias + + attn_weights = ACT2FN[self.attention_activation](qk).type_as(qk) + + if key_padding_mask is not None: + attn_weights = attn_weights * key_padding_mask.unsqueeze(1) + + return attn_weights + + def softmax_attention(self, query, key, key_padding_mask, pidx): + bsz, src_len, _ = key.size() + tgt_len = query.size(1) if pidx is None else pidx + 1 + + # (target_sequence_length X source_sequence_length) + bias = self.rel_pos_bias(max(tgt_len, src_len))[:, :src_len] + if pidx is not None: + if query.size(1) != 1: + raise ValueError("Position offset provided with queries longer than 1 token") + # source_sequence_length + bias = bias[pidx] + else: + # (target_sequence_length X source_sequence_length) + bias = bias[:tgt_len] + + # scaled attention + query = query * self.scaling + # (batch_size X target_sequence_length X source_sequence_length) + qk = torch.bmm(query, key.transpose(1, 2)) + bias + + if key_padding_mask is not None: + qk = qk.masked_fill((1 - key_padding_mask).unsqueeze(1).to(torch.bool), float("-inf")) + + attn_weights = self.softmax(qk).type_as(qk) + return attn_weights + + def forward( + self, + query, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Gated cross-attention used in Mega + + Args: + query (`torch.Tensor` of shape `(target_sequence_length, batch_size, hidden_size)`): + The self (or target) sequence input used as query inputs for cross-attention + key (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`): + The cross (or source) sequence input with shape used as keys in cross-attention + value (`torch.Tensor` of shape `(source_sequence_length, batch_size, hidden_size)`): + The cross (or source) sequence input with shape used as values in cross-attention + key_padding_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*): + Padding mask corresponding to the source sequence, where entries are 1 for *not masked* and 0 for + *masked* tokens + past_key_values (`tuple(torch.FloatTensor)`, *optional*): + If provided, the hidden state returned from the previous timestep during incremental decoding; expects + that prior cross-attention keys and values will be the last two items in the tuple + output_attentions (`bool`, defaults to `False`): + Whether or not to return the cross-attention weights. + use_cache (`bool`, defaults to `False`): + Whether to perfom incremental decoding; uses `prev_state` as the prior timestep, and returns the + updated EMA hidden state for use in the next step + + Returns: + `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and + inputs: + - **hidden_states** (`torch.FloatTensor` of shape `(target_sequence_length, batch_size, hidden_size)`) -- + Hidden states from target sequence updated by gated cross-attention + - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape + `(batch_size, source_sequence_length, target_sequence_length)` -- The pairwise cross-attention weights + corresponding to each token in the source and target sequences + - **cross_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + source_sequence_length, config.shared_representation_size)` -- The cross-attention key state for use in + the next step of incremental decoding + - **cross_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + source_sequence_length, config.hidden_size)` -- The cross-attention value state for use in the next step + of incremental decoding + """ + + seq_len, bsz, embed_dim = query.size() + if embed_dim != self.config.hidden_size: + raise ValueError( + f"Unexpected embedding dimension received: input is {embed_dim} but expected {self.config.hidden_size}" + ) + + if past_key_values is not None: + # make sure the inputs only have a sequence length of 1 if we're doing incremental decoding + if seq_len != 1: + raise ValueError(f"Incremental decoding requested with self-sequence length > 1: {seq_len}") + # expect past_key_values to have (self_key, self_value, self_ema, cross_key, cross_value) + prev_cross_key, prev_cross_value = past_key_values[-2:] + key = value = None + + # use the self-attention cache to get the position id of the current step + prev_self_key = past_key_values[0] + num_incremental_steps = prev_self_key.size(1) + 1 + else: + prev_cross_key = prev_cross_value = None + # we still need the position id if we're doing incremental decoding (past_key_values will be None for the first step) + num_incremental_steps = 0 if use_cache and (seq_len == 1) else None + + full_query = query + if self.prenorm: + full_query = self.norm(full_query) + + # (target_sequence_length X batch_size X 2*hidden_size + shared_representation_size) + query_projected = self.q_proj(full_query) + # split the query projections into separate components + # - residual_weight is passed through sigmoid and sent through elementwise multiplication to the gated/weighted targets prior to being added to the query directly + # - target_gate is a silu-gated tensor that is multiplied by the attention-weighted target below prior to residual connection + # - attention_query is the part that is passed to the attention function + residual_weight, target_gate, attention_query = torch.split( + query_projected, + [self.config.hidden_size, self.config.hidden_size, self.config.shared_representation_size], + dim=-1, + ) + + # (target_sequence_length X batch_size X hidden_size) + residual_weight = torch.sigmoid(residual_weight) + target_gate = F.silu(target_gate) + + if key is None: + if value is not None: + raise ValueError("Key and value must be `None` simultaneously") + projected_key = projected_value = None + else: + # (source_sequence_length X batch_size X shared_representation_size) + projected_key = self.k_proj(key) + # (source_sequence_length X batch_size X hidden_size) + projected_value = self.activation(self.v_proj(key)) + + # (target_sequence_length X batch_size X shared_representation_size) + # -> (batch_size X target_sequence_length X shared_representation_size) + attention_query = attention_query.transpose(0, 1) + if projected_key is not None: + projected_key = projected_key.transpose(0, 1) + if projected_value is not None: + projected_value = projected_value.transpose(0, 1) + + # if we're doing incremental decoding, k and v are None and need to be overwritten with past values + if past_key_values is not None: + projected_key = prev_cross_key + projected_value = prev_cross_value + + # if we're returning the cache for later use, store these now for later return (can be done without having past_key_values provided) + if use_cache: + updated_cross_key = projected_key + updated_cross_value = projected_value + + ctx_len = projected_key.size(1) + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + if key_padding_mask.size(0) != bsz: + raise ValueError("Key padding mask does not align on the batch dimension") + if key_padding_mask.size(1) != ctx_len: + raise ValueError("Key padding mask does not align on the sequence length dimension") + + if self.attention_activation == "softmax": + attn_weights = self.softmax_attention( + attention_query, projected_key, key_padding_mask, num_incremental_steps + ) + else: + attn_weights = self.element_attention( + attention_query, projected_key, key_padding_mask, num_incremental_steps + ) + + projected_value = self.hidden_dropout(projected_value, batch_first=True) + kernel = self.attention_dropout(attn_weights) + # (batch_size X target_sequence_length X hidden_size) + # -> (target_sequence_length X batch_size X hidden_size) + weighted_targets = torch.bmm(kernel, projected_value).transpose(0, 1) + # (target_sequence_length X batch_size X hidden_size) + weighted_targets = self.activation(self.h_proj(weighted_targets * target_gate)) + weighted_targets = self.dropout(weighted_targets) + out = torch.addcmul(query, residual_weight, weighted_targets - query) + + if not self.prenorm: + out = self.norm(out) + + outputs = (out, attn_weights) if output_attentions else (out,) + if use_cache: + outputs = outputs + (updated_cross_key, updated_cross_value) + + return outputs + + +class MegaMovingAverageGatedAttention(nn.Module): + """ + Pure PyTorch implementation of Mega block; see https://arxiv.org/abs/2209.10655 and original fairseq implementation + at https://github.com/facebookresearch/mega (copyright Meta Research, licensed under MIT License) + + Differences from original implementation include hidden state refactor and fixed inconsistency with additive / + multiplicative attention masks + """ + + def __init__(self, config: MegaConfig): + super().__init__() + self.config = config + self.activation = ACT2FN[self.config.activation] + self.scaling = ( + self.config.shared_representation_size**-0.5 if self.config.attention_activation == "softmax" else None + ) + self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=self.config.use_feature_dropout) + self.hidden_dropout = MegaDropout( + self.config.hidden_dropout_prob, is_featurewise=self.config.use_feature_dropout + ) + # attention dropout is standard dropout + self.attention_dropout = MegaDropout(self.config.attention_probs_dropout_prob, is_featurewise=False) + + self.norm = MegaSequenceNorm( + self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine + ) + self.ema_gate = MegaMultiDimensionDampedEma(config) + + self.v_proj = nn.Linear(self.config.hidden_size, self.config.intermediate_size) + self.mx_proj = nn.Linear( + self.config.hidden_size, + self.config.shared_representation_size + self.config.intermediate_size + 2 * self.config.hidden_size, + ) + self.h_proj = nn.Linear(self.config.intermediate_size, self.config.hidden_size) + + self.qk_weight = nn.Parameter(torch.Tensor(2, self.config.shared_representation_size)) + self.qk_bias = nn.Parameter(torch.Tensor(2, self.config.shared_representation_size)) + + if self.config.relative_positional_bias == "simple": + self.rel_pos_bias = MegaSimpleRelativePositionalBias(config) + elif self.config.relative_positional_bias == "rotary": + self.rel_pos_bias = MegaRotaryRelativePositionalBias(config) + else: + raise ValueError(f"Unknown relative positional bias: {self.config.relative_positional_bias}") + + self.softmax = nn.Softmax(dim=-1) + self.attention_function = ( + self.softmax_attention if self.config.attention_activation == "softmax" else self.element_attention + ) + + def element_attention(self, query, key, padding_mask, causal_mask): + """ + Apply element-wise attention via relu^2 or laplace. Same as original implementation but with standardized + causal attention mask. Expects the Hugging Face standard attention mask paradigm: 1 for not masked, and 0 for + masked. + """ + seq_len = key.size(2) + if padding_mask is not None: + # (batch_size X number of chunks X 1) + lengths = padding_mask.sum(-1, keepdim=True) + # (batch_size X number of chunks X 1 X 1) + lengths = lengths.clamp(min=1.0).unsqueeze(-1) + else: + lengths = seq_len + + if causal_mask is not None: + lengths = causal_mask.sum(dim=-1, keepdim=True) + + # (sequence_length X sequence_length) + bias = self.rel_pos_bias(seq_len) + if seq_len != query.size(2): + if query.size(2) != 1: + raise ValueError("Size mismatch between Q and K in element attention") + # (1 X sequence_length) + bias = bias[-1:] + + # (batch_size X number of chunks X sequence_length X sequence_length) + qk = torch.matmul(query, key.transpose(2, 3)) / lengths + bias + + attn_weights = ACT2FN[self.config.attention_activation](qk).type_as(qk) + + if padding_mask is not None: + attn_weights = attn_weights * padding_mask.unsqueeze(2) + + if causal_mask is not None: + attn_weights = attn_weights * causal_mask + + return attn_weights + + def softmax_attention(self, query, key, padding_mask, causal_mask): + "Standard softmax self-attention, as in the original Transformer paper" + seq_len = key.size(2) + # (sequence_length X sequence_length) + bias = self.rel_pos_bias(seq_len) + if seq_len != query.size(2): + if query.size(2) != 1: + raise ValueError("Size mismatch between Q and K in softmax attention") + # (1 X sequence_length) + bias = bias[-1:] + + # scaled attention + query = query * self.scaling + + # (batch_size x number of chunks x chunk_size x chunk_size) if chunking + # (batch_size x 1 x sequence_length x sequence_length) otherwise + qk = torch.matmul(query, key.transpose(2, 3)) + bias + + # apply causal mask (presumed to be 1/0 for not masked / masked) + # additive, but convert to 0/-inf (which is not explicitly in the Mega source code) + if causal_mask is not None: + additive_causal_mask = torch.zeros_like(causal_mask, dtype=qk.dtype) + additive_causal_mask = additive_causal_mask.masked_fill((1 - causal_mask).bool(), float("-inf")) + qk = qk + additive_causal_mask + + if padding_mask is not None: + # 1 for tokens which are *not masked* + # 0 for tokens which are *masked* + # replace masked tokens with -inf to make softmax ignore them + # need to invert the padding mask to match what mega original did + padding_mask = 1 - padding_mask + padding_mask_all = padding_mask.all(dim=-1, keepdim=True) + padding_mask = torch.logical_and(padding_mask, ~padding_mask_all) + qk = qk.masked_fill(padding_mask.unsqueeze(2).to(torch.bool), float("-inf")) + + attn_weights = self.softmax(qk).type_as(qk) + return attn_weights + + def forward( + self, + input, + padding_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.Tensor]] = None, + output_attentions=False, + use_cache=False, + ): + """ + Mega's self-attention block, which combines multi-headed EMA with traditional self-attention + + Args: + input (`torch.Tensor` of shape `(sequence_length, batch_size, hidden_size)`): + Hidden states to be updated by Mega's self-attention + padding_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* + or 0 for *masked* + causal_mask (`torch.LongTensor` of shape `(sequence_length, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to causal attention, where elements are either 1 for *not + masked* or 0 for *masked* + past_key_values (`tuple(torch.Tensor)`, *optional*): + The hidden states returned from the previous timestep during incremental decoding; expects that + self-attention key, value, and EMA states are the first 3 entries in the tuple + output_attentions (`bool`, default `False`): + Whether to return self-attention weights + use_cache (`bool`, default `False`): + Whether to perfom incremental decoding; uses `past_key_values` as prior state, and returns the updated + states for use in the next step + + Returns: + `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and + inputs: + - **hidden_states** (`torch.FloatTensor` of shape `(sequence_length, batch_size, hidden_size)`) -- Hidden + states from target sequence updated by Mega's self-attention + - **attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape + `(batch_size, 1, sequence_length, sequence_length)` -- The self-attention weights corresponding to how + each token in the input sequence attends to every other token + - **self_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + sequence_length, config.shared_representation_size)` -- The self-attention key state for use in the next + step of incremental decoding + - **self_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + sequence_length, config.hidden_size)` -- The self-attention value state for use in the next step of + incremental decoding + - **self_ema_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape + `(batch_size, config.ndim)` The incremental EMA state for use in the next step of incremental decoding. + """ + + seq_len, bsz, embed_dim = input.size() + if embed_dim != self.config.hidden_size: + raise ValueError(f"Input embedding dimension should be {self.config.hidden_size}; received {embed_dim}") + + # store inputs for residual connection and handle pre-norm if requested + residual = input + if self.config.normalize_before_mega: + input = self.norm(input) + + # (sequence_length X batch_size X hidden_size) -> (sequence_length X batch_size X intermediate_size) + value = self.activation(self.v_proj(input)) + + # unpack the incremental state if provided + # assumed to be (self K, self V, self EMA state, cross K, cross V) + # also assumes that incremental decoding is working one token at a time, so input sequence length must be 1 + if self.config.is_decoder and (past_key_values is not None): + if seq_len > 1: + raise ValueError(f"Incremental decoding only supports self sequence length of 1; received {seq_len}") + # the first 3 items in the saved states will be these regardless of whether cross-attention is present + prev_self_key, prev_self_value, prev_ema_state = past_key_values[0:3] + else: + prev_self_key = prev_self_value = prev_ema_state = None + + # ema output is (sequence_length x batch_size x hidden_size) + # updated_ema_state will be None if use_cache=False; otherwise (batch_size, config.ndim) + ema_out, updated_ema_state = self.ema_gate( + input, attention_mask=padding_mask, prev_state=prev_ema_state, use_cache=use_cache + ) + ema_out = self.dropout(ema_out) + + # (sequence_length X batch_size X hidden_size) + # -> (sequence_length X batch_size X 2*hidden_size + config.shared_representation_size + config.intermediate_size) + # - residual_weight -> sigmoid -> applied to residual connection in torch.addcmul + # - query_key_gates -> split into two components: query_key becomes query and key for attention input, gates becomes gating for self-attention output + # - intermediate_state -> added to weighted attention output, sent through activation, and has inputs subtracted during + # torch.addcmul to create the final layer output + base = self.mx_proj(ema_out) + residual_weight, query_key_gates, intermediate_state = torch.split( + base, + [ + self.config.hidden_size, + self.config.shared_representation_size + self.config.intermediate_size, + self.config.hidden_size, + ], + dim=-1, + ) + + # (sequence_length X batch_size X hidden_size) + residual_weight = torch.sigmoid(residual_weight) + + # (sequence_length X batch_size X shared_representation_size + intermediate_size) + query_key_gates = F.silu(query_key_gates) + + # split into two different tensors: one for Q/K usage and the other for gating self-attention + query_key, attention_gate = torch.split( + query_key_gates, [self.config.shared_representation_size, self.config.intermediate_size], dim=-1 + ) + + # (sequence_length X batch_size X shared_representation_size) + # -> (sequence_length X batch_size X 1 X shared_representation_size) + # -> (sequence_length X batch_size X 2 X shared_representation_size) + query_key = query_key.unsqueeze(2) * self.qk_weight + self.qk_bias + + # (sequence_length X batch_size X 2 X shared_representation_size) + # -> 2 tensors of (sequence_length X batch_size X shared_representation_size) + query, key = torch.unbind(query_key, dim=2) + + # (sequence_length X batch_size X dimension) + # -> (batch_size X sequence_length X dimension) + # where `dimension` is either shared_representation_size (queries and keys) or intermediate_size (values) + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + if self.config.is_decoder: + # combine history and current to save updated state (if history is provided) + # when chunking is applied, the past states will be None at the end of the chunk, in + # which case, proceed as if no K/V history had been provided + # saved states are stored with shape (batch_size X sequence_length X dimension) + if prev_self_key is not None: + key = torch.cat([prev_self_key, key], dim=1) + if prev_self_value is not None: + value = torch.cat([prev_self_value, value], dim=1) + + # if not chunking, store as-is + if not self.config.use_chunking: + updated_self_key = key + updated_self_value = value + else: + curr_len = key.size(1) % self.config.chunk_size + if curr_len == 0: + # if we're chunking and have reached the end of a chunk, wipe out the saved state + updated_self_key = None + updated_self_value = None + else: + updated_self_key = key + updated_self_value = value + + ctx_len = key.size(1) # potentially differs from seq_len because of incremental decoding + if not self.config.use_chunking: + # if we're not chunking, treat the entire sequence as one long chunk + # (batch_size X sequence_length X dimension) -> (batch_size X 1 X sequence_length X dimension) + query = query.unsqueeze(1) + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if padding_mask is not None: + # (batch_size X sequence_length) -> (batch_size X 1 X sequence_length) + padding_mask = padding_mask.unsqueeze(1) + else: + # otherwise, split the sequences in the batch into `n_chunks` chunks of size `chunk_size` + if seq_len < self.config.chunk_size: + query = query.unsqueeze(1) + else: + # (batch_size X sequence_length X dimension) -> (batch_size X n_chunks X chunk_size X dimension) + n_chunks = seq_len // self.config.chunk_size + query = query.reshape(bsz, n_chunks, self.config.chunk_size, self.config.shared_representation_size) + + if ctx_len < self.config.chunk_size: + key = key.unsqueeze(1) + value = value.unsqueeze(1) + if padding_mask is not None: + padding_mask = padding_mask.unsqueeze(1) + else: + # (batch_size X sequence_length X dimension) -> (batch_size X n_chunks X chunk_size X dimension) + n_chunks = ctx_len // self.config.chunk_size + key = key.reshape(bsz, n_chunks, self.config.chunk_size, self.config.shared_representation_size) + value = value.reshape(bsz, n_chunks, self.config.chunk_size, self.config.intermediate_size) + if padding_mask is not None: + padding_mask = padding_mask.view(bsz, n_chunks, self.config.chunk_size) + + # this is in the original Mega implementation to work around fork/join parallelism not supporting optional types + if padding_mask is not None and padding_mask.dim() == 0: + padding_mask = None + + attn_weights = self.attention_function(query, key, padding_mask=padding_mask, causal_mask=causal_mask) + + value = self.hidden_dropout(value, batch_first=True) + kernel = self.attention_dropout(attn_weights) + + # (batch_size x n_chunks x chunk_size x intermediate_size) -> (sequence_length X batch_size X intermediate_size) + weighted_self_output = ( + torch.matmul(kernel, value).view(bsz, seq_len, self.config.intermediate_size).transpose(0, 1) + ) + + # (sequence_length X batch_size X intermediate_size) -> (sequence_length X batch_size X hidden_size) + weighted_self_output = self.activation(intermediate_state + self.h_proj(weighted_self_output * attention_gate)) + weighted_self_output = self.dropout(weighted_self_output) + # (sequence_length X batch_size X hidden_size) + out = torch.addcmul(residual, residual_weight, weighted_self_output - residual) + + if not self.config.normalize_before_mega: + out = self.norm(out) + + return_values = (out, attn_weights) if output_attentions else (out,) + + if self.config.is_decoder: + return_values = return_values + (updated_self_key, updated_self_value, updated_ema_state) + + return return_values + + +class MegaNormalizedFeedForwardNetwork(nn.Module): + """ + Normalized feed-forward network used in Mega blocks. Left as-is from original Mega repo aside from retrieving args + from Hugging Face config + """ + + def __init__(self, config: MegaConfig): + super().__init__() + + self.config = config + self.hidden_dim = config.nffn_hidden_size + self.act_fn = config.activation + self.activation = ACT2FN[config.activation] + + self.dropout = MegaDropout(self.config.dropout_prob, is_featurewise=self.config.use_feature_dropout) + self.hidden_dropout = MegaDropout( + self.config.nffn_activation_dropout_prob, is_featurewise=self.config.use_feature_dropout + ) + + self.prenorm = self.config.normalize_before_ffn + self.norm = MegaSequenceNorm( + self.config.normalization_type, self.config.hidden_size, affine=self.config.norm_affine + ) + + self.fc1 = nn.Linear(self.config.hidden_size, self.config.nffn_hidden_size) + self.fc2 = nn.Linear(self.config.nffn_hidden_size, self.config.hidden_size) + + def forward(self, inputs): + residual = inputs + + if self.prenorm: + inputs = self.norm(inputs) + + hidden = self.activation(self.fc1(inputs)) + hidden = self.hidden_dropout(hidden) + output = self.fc2(hidden) + output = self.dropout(output) + output = output + residual + + if not self.prenorm: + output = self.norm(output) + + return output + + +class MegaBlock(nn.Module): + def __init__(self, config: MegaConfig): + super().__init__() + self.seq_len_dim = 1 + self.mega_layer = MegaMovingAverageGatedAttention(config) + self.nffn = MegaNormalizedFeedForwardNetwork(config) if config.use_normalized_ffn else None + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.cross_attn = MegaGatedCrossAttention(config) + else: + self.cross_attn = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + causal_mask: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[torch.FloatTensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor]: + """ + A single Mega layer: either encoder or decoder, with optional cross-attention and optional normalized + feed-forward layer + + Args: + hidden_states (`torch.Tensor` of shape `(target_sequence_length, batch_size, hidden_size)`): + Hidden states to be updated by the Mega block + attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indicates which entries in the self/target sequence are to be ignored (mostly due to padding), where + elements are either 1 for *not masked* or 0 for *masked*. Causal attention is enforced internally. + causal_mask (`torch.LongTensor` of shape `(sequence_length, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to causal attention, where elements are either 1 for *not + masked* or 0 for *masked* + encoder_hidden_states (`torch.Tensor`, of shape `(source_sequence_length, batch_size, hidden_size)`, *optional*): + Encoder hidden states to be used for cross-attention (and required for encoder-decoder model setup) + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*): + Indicates which entries in the cross/source sequence are to be ignored (mostly due to padding), where + elements are either 1 for *not masked* or 0 for *masked*. + past_key_value (`tuple(torch.Tensor)`, *optional*): + The hidden states returned from the previous timestep during incremental decoding; expects that + self-attention key, value, and EMA states are the first 3 entries in the tuple, and (if doing + cross-attention) cross-attention key and value are the last 2 entries in the tuple + output_attentions (`bool`, default `False`): + Whether to return self-attention weights + use_cache (`bool`, default `False`): + Whether to perfom incremental decoding; uses `past_key_value` as prior state, and returns the updated + states for use in the next step + + Returns: + `tuple(torch.FloatTensor)` containing various elements depending on configuration ([`MegaConfig`]) and + inputs: + - **hidden_states** (`torch.FloatTensor` of shape `(target_sequence_length, batch_size, hidden_size)`) -- + Hidden states from target sequence updated by Mega + - **self_attn_weights** (*optional*, returned when `output_attentions=True`) `torch.FloatTensor` of shape + `(batch_size, 1, target_sequence_length, target_sequence_length)` -- The self-attention weights + corresponding to how each token in the input sequence attends to every other token + - **cross_attn_weights** (*optional*, returned when `output_attentions=True` and + `config.add_cross_attention=True`) `torch.FloatTensor` of shape `(batch_size, source_sequence_length, + target_sequence_length)` -- Pairwise cross-attention weights between every entry in the source sequence + and target sequence + - **self_key** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + sequence_length, config.shared_representation_size)` -- The self-attention key state for use in the next + step of incremental decoding + - **self_value** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape `(batch_size, + sequence_length, config.hidden_size)` -- The self-attention value state for use in the next step of + incremental decoding + - **self_ema_state** (*optional*, returned when `use_cache=True`) `torch.FloatTensor` of shape + `(batch_size, config.ndim)` The incremental EMA state for use in the next step of incremental decoding. + - **cross_key** (*optional*, returned when `use_cache=True` and `config.is_decoder=True`) + `torch.FloatTensor` of shape `(batch_size, source_sequence_length, config.shared_representation_size)` -- + The cross-attention key state for use in the next step of incremental decoding + - **cross_value** (*optional*, returned when `use_cache=True` and `config.is_decoder=True`) + `torch.FloatTensor` of shape `(batch_size, source_sequence_length, config.hidden_size)` -- The + cross-attention value state for use in the next step of incremental decoding + """ + + # incremental decoding in the MegaMultiDimensionDampedEma module requires that the attention mask has the same + # sequence length as the input tensor; if we're caching incremental states, we assume the input + # sequence length is 1 (Mega will break otherwise), so we take the padding mask for the final + # token in the input (mask is received as [batch X sequence length]) + if use_cache and (past_key_value is not None) and (attention_mask is not None): + mega_padding_mask = attention_mask[:, -1].unsqueeze(-1) + else: + mega_padding_mask = attention_mask + + mega_outputs = self.mega_layer( + input=hidden_states, + padding_mask=mega_padding_mask, + causal_mask=causal_mask, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + new_hidden_states = mega_outputs[0] + self_key, self_value, self_ema_state = mega_outputs[-3:] if use_cache else (None, None, None) + self_attention_weights = mega_outputs[1] if output_attentions else None + + # optional cross attention + if self.cross_attn is not None: + if encoder_hidden_states is None: + raise ValueError("Requested cross-attention without providing encoder hidden states") + + cross_attn_outputs = self.cross_attn( + query=new_hidden_states, + key=encoder_hidden_states, + value=encoder_hidden_states, + key_padding_mask=encoder_attention_mask, + past_key_values=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + # update the hidden state from cross attention + new_hidden_states = cross_attn_outputs[0] + # store cross-attention k/v if caching + cross_key, cross_value = cross_attn_outputs[-2:] if use_cache else (None, None) + cross_attention_weights = cross_attn_outputs[1] if output_attentions else None + + # optional NFFN follows cross attention + if self.nffn is not None: + new_hidden_states = self.nffn(new_hidden_states) + + outs = (new_hidden_states,) + if output_attentions: + outs = outs + (self_attention_weights,) + if self.cross_attn is not None: + outs = outs + (cross_attention_weights,) + + if use_cache: + new_key_values = ( + self_key, + self_value, + self_ema_state, + ) + if self.cross_attn is not None: + new_key_values = new_key_values + (cross_key, cross_value) + + outs = outs + (new_key_values,) + + return outs + + +# copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->Mega +class MegaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class MegaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MegaConfig + base_model_prefix = "mega" + supports_gradient_checkpointing = False + _no_split_modules = ["MegaMovingAverageGatedAttention"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, MegaMultiDimensionDampedEma): + with torch.no_grad(): + # delta & alpha + nn.init.normal_(module.damping_factor, mean=0.0, std=self.config.ema_delta_alpha_range) + nn.init.normal_(module.decay_factor, mean=0.0, std=self.config.ema_delta_alpha_range) + # beta [1, -1, 1, -1, ...] seems more stable. + val = torch.ones(self.config.ema_projection_size, 1) + if self.config.ema_projection_size > 1: + idx = torch.tensor(list(range(1, self.config.ema_projection_size, 2))) + val.index_fill_(0, idx, -1.0) + module.ema_expansion_matrix.normal_(mean=0.0, std=self.config.ema_beta_range).add_(val) + # gamma & omega + nn.init.normal_(module.kernel_projection_matrix, mean=0.0, std=self.config.ema_gamma_omega_range) + nn.init.normal_(module.residual_weight, mean=0.0, std=self.config.ema_gamma_omega_range) + elif isinstance(module, MegaSimpleRelativePositionalBias): + nn.init.normal_(module.rel_pos_bias, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, MegaRotaryRelativePositionalBias): + nn.init.normal_(module.alpha, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.b_param, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, MegaScaleNorm): + if self.config.norm_affine: + nn.init.constant_(module.scalar, 1.0) + elif isinstance(module, MegaRMSNorm): + if self.config.norm_affine: + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, MegaMovingAverageGatedAttention): + # linear layers covered separately by the generic nn.Linear init below + nn.init.normal_(module.qk_weight, mean=0.0, std=self.config.initializer_range) + nn.init.constant_(module.qk_bias, 0.0) + elif isinstance(module, nn.Linear): + # initializes all linear layers in the entire network + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +MEGA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MegaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MEGA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `add_token_type_embeddings` parameter + set to `True`. All the value in this tensor should be always < config.type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare MEGA Model transformer outputting raw hidden-states without any specific head on top.", + MEGA_START_DOCSTRING, +) +class MegaModel(MegaPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added after self-attention, following the architecture described in *Mega: Moving Average + Equipped Gated Attention*_ by Xuezhe Ma, Chunting Zhou, Xiang Kong, Junxian He, Liangke Gui, Graham Neubig, + Jonathan May, and Luke Zettlemoyer + + To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to + `True` and `bidirectional` set to `False`. To be used in a Seq2Seq model, the model needs to initialized with both + `is_decoder=True` and `bidirectional=False` argument as well as `add_cross_attention` set to `True`; an + `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Mega: Moving Average Equipped Gated Attention*: https://arxiv.org/abs/2209.10655 + + """ + + def __init__(self, config: MegaConfig, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embedding_layer = MegaEmbeddings(config) + self.layers = nn.ModuleList([MegaBlock(config) for _ in range(config.num_hidden_layers)]) + + self.pooler = MegaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing (retained from RoBERTa code) + self.post_init() + + def get_input_embeddings(self): + return self.embedding_layer.word_embeddings + + def set_input_embeddings(self, value): + self.embedding_layer.word_embeddings = value + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + device = inputs_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.config.use_chunking: + input_shape = torch.tensor([input_shape[0], self.config.chunk_size]) + + batch_size, sequence_length = input_shape + + if self.config.use_chunking and (sequence_length > self.config.chunk_size): + if sequence_length % self.config.chunk_size != 0: + raise ValueError( + f"config.use_chunking is activated; input sequence length must be shorter than or a multiple of config.chunk_size\nreceived sequence length of {sequence_length} with chunk size {self.config.chunk_size}" + ) + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + + # Mega expects the causal mask to be a 2D square matrix of (from) x (to) over the input sequence length + # the HF utility function generates a 3D causal mask which includes batch size, so we'll create a dummy + # mask with the correct device and all ones + temp_mask_for_extension = torch.ones((1, sequence_length), dtype=torch.long, device=device) + causal_mask = self.create_extended_attention_mask_for_decoder(input_shape, temp_mask_for_extension) + + # get rid of batch dimension in the generated mask; result is (sequence_length X sequence_length) + causal_mask = causal_mask.squeeze(0) + else: + use_cache = False + causal_mask = None + + # if using cache, make sure we have a tuple of tuples which matches the length of our hidden layers + if (past_key_values is not None) and (len(past_key_values) != self.config.num_hidden_layers): + raise ValueError( + f"Received past key/value cache with size mismatch; expected {self.config.num_hidden_layers}, received {len(past_key_values)}" + ) + + # get embeddings (batch X sequence length X embed dim) + embedding_output = self.embedding_layer( + input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + # transpose for Mega --> (seq len X batch X embed dim) + hidden_states = embedding_output.transpose(0, 1) + + # we expect encoder hidden states to also have batch first in line + # with typical Hugging Face behavior (which is also how we return them) + # Mega expects sequence length first, so do the same transpose here + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.transpose(0, 1) + + # pass through mega layers + all_hidden_states = (embedding_output,) if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + next_decoder_cache = () if use_cache else None + for i, mega_layer in enumerate(self.layers): + current_decoder_cache = past_key_values[i] if past_key_values is not None else None + mega_outputs = mega_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=current_decoder_cache, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = mega_outputs[0] + if output_hidden_states: + # store layer-wise hidden states in the way that the user expects + # (seq len X batch X embed dim) --> (batch X seq len X embed dim) + all_hidden_states += (hidden_states.transpose(0, 1),) + if output_attentions: + self_attn_weights = mega_outputs[1] + all_self_attentions += (self_attn_weights,) + if self.config.add_cross_attention: + cross_attn_weights = mega_outputs[2] + all_cross_attentions += (cross_attn_weights,) + if use_cache: + updated_cache = mega_outputs[-1] + next_decoder_cache += (updated_cache,) + + # transpose final hidden states + hidden_states = hidden_states.transpose(0, 1) + + # optional pooling layer + pooled_output = self.pooler(hidden_states) if self.pooler is not None else None + + if not return_dict: + return (hidden_states, pooled_output) + ( + all_hidden_states, + next_decoder_cache, + all_self_attentions, + all_cross_attentions, + ) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled_output, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """MEGA Model with a `language modeling` head on top for CLM fine-tuning.""", MEGA_START_DOCSTRING +) +class MegaForCausalLM(MegaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: MegaConfig): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `MegaForCausalLM` as a standalone, add `is_decoder=True.`") + + self.mega = MegaModel(config, add_pooling_layer=False) + + if config.add_lm_hidden_dense_layer: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.hidden_activation = nn.Tanh() + else: + self.dense = None + self.hidden_activation = None + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MegaForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("mnaylor/mega-base-wikitext") + >>> config = AutoConfig.from_pretrained("mnaylor/mega-base-wikitext") + >>> config.is_decoder = True + >>> config.bidirectional = False + >>> model = MegaForCausalLM.from_pretrained( + ... "mnaylor/mega-base-wikitext", config=config, ignore_mismatched_sizes=True + ... ) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.mega( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + if self.dense is not None: + sequence_output = self.dense(sequence_output) + sequence_output = self.hidden_activation(sequence_output) + + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings("""MEGA Model with a `language modeling` head on top.""", MEGA_START_DOCSTRING) +class MegaForMaskedLM(MegaPreTrainedModel): + _tied_weights_keys = ["mlm_head.weight"] + + def __init__(self, config: MegaConfig): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `MegaForMaskedLM`, set `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.mega = MegaModel(config, add_pooling_layer=False) + if config.add_lm_hidden_dense_layer: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.hidden_activation = nn.Tanh() + else: + self.dense = None + self.hidden_activation = None + self.mlm_head = nn.Linear(config.hidden_size, config.vocab_size) + self.dropout = nn.Dropout(config.dropout_prob) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.mlm_head + + def set_output_embeddings(self, new_embeddings): + self.mlm_head = new_embeddings + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + expected_output="' Paris'", + expected_loss=0.1, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mega( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + if self.dense is not None: + sequence_output = self.dense(sequence_output) + sequence_output = self.hidden_activation(sequence_output) + prediction_scores = self.mlm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + MEGA Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + MEGA_START_DOCSTRING, +) +class MegaForSequenceClassification(MegaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.mega = MegaModel(config, add_pooling_layer=False) + self.classifier = MegaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mega( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + MEGA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + MEGA_START_DOCSTRING, +) +class MegaForMultipleChoice(MegaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.mega = MegaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.mega( + flat_input_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + MEGA Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + MEGA_START_DOCSTRING, +) +class MegaForTokenClassification(MegaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mega = MegaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mega( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Mega +class MegaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + MEGA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MEGA_START_DOCSTRING, +) +class MegaForQuestionAnswering(MegaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.mega = MegaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MEGA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mega( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "MegaForCausalLM", + "MegaForMaskedLM", + "MegaForMultipleChoice", + "MegaForQuestionAnswering", + "MegaForSequenceClassification", + "MegaForTokenClassification", + "MegaModel", + "MegaPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03b556e2eddf6d5f81f6f4a15346596dd5a32f85 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mmbt import * + from .modeling_mmbt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py new file mode 100644 index 0000000000000000000000000000000000000000..1c58e4e6cdd0d0a8c87b9b94ca896c87970b5654 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/configuration_mmbt.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MMBT configuration""" + +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class MMBTConfig: + """ + This is the configuration class to store the configuration of a [`MMBTModel`]. It is used to instantiate a MMBT + model according to the specified arguments, defining the model architecture. + + Args: + config ([`PreTrainedConfig`]): + Config of the underlying Transformer models. Its values are copied over to use a single config. + num_labels (`int`, *optional*): + Size of final Linear layer for classification. + modal_hidden_size (`int`, *optional*, defaults to 2048): + Embedding dimension of the non-text modality encoder. + """ + + def __init__(self, config, num_labels=None, modal_hidden_size=2048): + self.__dict__ = config.__dict__ + self.modal_hidden_size = modal_hidden_size + if num_labels: + self.num_labels = num_labels + + +__all__ = ["MMBTConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py new file mode 100644 index 0000000000000000000000000000000000000000..45ae577f7fced2d276f4f54c5bf859e27e08ebae --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/mmbt/modeling_mmbt.py @@ -0,0 +1,410 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MMBT model.""" + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss + +from ....modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput +from ....modeling_utils import ModuleUtilsMixin +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MMBTConfig" + + +class ModalEmbeddings(nn.Module): + """Generic Modal Embeddings which takes in an encoder, and a transformer embedding.""" + + def __init__(self, config, encoder, embeddings): + super().__init__() + self.config = config + self.encoder = encoder + self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size) + self.position_embeddings = embeddings.position_embeddings + self.token_type_embeddings = embeddings.token_type_embeddings + self.word_embeddings = embeddings.word_embeddings + self.LayerNorm = embeddings.LayerNorm + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + + def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None): + token_embeddings = self.proj_embeddings(self.encoder(input_modal)) + seq_length = token_embeddings.size(1) + + if start_token is not None: + start_token_embeds = self.word_embeddings(start_token) + seq_length += 1 + token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1) + + if end_token is not None: + end_token_embeds = self.word_embeddings(end_token) + seq_length += 1 + token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1) + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_modal.device) + position_ids = position_ids.unsqueeze(0).expand(input_modal.size(0), seq_length) + + if token_type_ids is None: + token_type_ids = torch.zeros( + (input_modal.size(0), seq_length), dtype=torch.long, device=input_modal.device + ) + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = token_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +MMBT_START_DOCSTRING = r""" + MMBT model was proposed in [Supervised Multimodal Bitransformers for Classifying Images and + Text](https://github.com/facebookresearch/mmbt) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine. + It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, and + obtain state-of-the-art performance on various multimodal classification benchmark tasks. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MMBTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. + transformer (`nn.Module`): A text transformer that is used by MMBT. + It should have embeddings, encoder, and pooler attributes. + encoder (`nn.Module`): Encoder for the second modality. + It should take in a batch of modal inputs and return k, n dimension embeddings. +""" + +MMBT_INPUTS_DOCSTRING = r""" + Args: + input_modal (`torch.FloatTensor` of shape `(batch_size, ***)`): + The other modality data. It will be the shape that the encoder for that type expects. e.g. With an Image + Encoder, the shape would be (batch_size, channels, height, width) + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. It does not expect [CLS] token to be added as it's + appended to the end of other modality embeddings. Indices can be obtained using [`AutoTokenizer`]. See + [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + modal_start_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for classification + tasks. + modal_end_tokens (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used. + attention_mask (*optional*) `torch.FloatTensor` of shape `(batch_size, sequence_length)`: + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, sequence_length)`: + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + modal_token_type_ids (*optional*) `torch.LongTensor` of shape `(batch_size, modal_sequence_length)`: + Segment token indices to indicate different portions of the non-text modality. The embeddings from these + tokens will be summed with the respective token embeddings for the non-text modality. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + modal_position_ids (`torch.LongTensor` of shape `(batch_size, modal_sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings for the non-text modality. + Selected in the range `[0, config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, embedding_dim)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare MMBT Model outputting raw hidden-states without any specific head on top.", + MMBT_START_DOCSTRING, +) +class MMBTModel(nn.Module, ModuleUtilsMixin): + def __init__(self, config, transformer, encoder): + super().__init__() + self.config = config + self.transformer = transformer + self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings) + + @add_start_docstrings_to_model_forward(MMBT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_modal, + input_ids=None, + modal_start_tokens=None, + modal_end_tokens=None, + attention_mask=None, + token_type_ids=None, + modal_token_type_ids=None, + position_ids=None, + modal_position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Returns: + + Examples: + + ```python + # For example purposes. Not runnable. + transformer = BertModel.from_pretrained("google-bert/bert-base-uncased") + encoder = ImageEncoder(args) + mmbt = MMBTModel(config, transformer, encoder) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_txt_shape = input_ids.size() + elif inputs_embeds is not None: + input_txt_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + modal_embeddings = self.modal_encoder( + input_modal, + start_token=modal_start_tokens, + end_token=modal_end_tokens, + position_ids=modal_position_ids, + token_type_ids=modal_token_type_ids, + ) + + input_modal_shape = modal_embeddings.size()[:-1] + + if token_type_ids is None: + token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device) + + txt_embeddings = self.transformer.embeddings( + input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + ) + + embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1) + + input_shape = embedding_output.size()[:-1] + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + else: + attention_mask = torch.cat( + [torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1 + ) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(input_shape, device=device) + else: + encoder_attention_mask = torch.cat( + [torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1 + ) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.transformer.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.transformer.pooler(sequence_output) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + +@add_start_docstrings( + """ + MMBT Model with a sequence classification/regression head on top (a linear layer on top of the pooled output) + """, + MMBT_START_DOCSTRING, + MMBT_INPUTS_DOCSTRING, +) +class MMBTForClassification(nn.Module): + r""" + **labels**: (*optional*) `torch.LongTensor` of shape `(batch_size,)`: + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: *Tuple* comprising various elements depending on the configuration (config) and inputs: **loss**: + (*optional*, returned when `labels` is provided) `torch.FloatTensor` of shape `(1,)`: Classification (or + regression if config.num_labels==1) loss. **logits**: + `torch.FloatTensor` of shape `(batch_size, config.num_labels)` Classification (or regression if + config.num_labels==1) scores (before SoftMax). + **hidden_states**: (*optional*, returned when `output_hidden_states=True`) list of `torch.FloatTensor` (one for + the output of each layer + the output of the embeddings) of shape `(batch_size, sequence_length, hidden_size)`: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: + (*optional*, returned when `output_attentions=True`) list of `torch.FloatTensor` (one for each layer) of shape + `(batch_size, num_heads, sequence_length, sequence_length)`: Attentions weights after the attention softmax, used + to compute the weighted average in the self-attention heads. + + Examples: + + ```python + # For example purposes. Not runnable. + transformer = BertModel.from_pretrained("google-bert/bert-base-uncased") + encoder = ImageEncoder(args) + model = MMBTForClassification(config, transformer, encoder) + outputs = model(input_modal, input_ids, labels=labels) + loss, logits = outputs[:2] + ```""" + + def __init__(self, config, transformer, encoder): + super().__init__() + self.num_labels = config.num_labels + + self.mmbt = MMBTModel(config, transformer, encoder) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward( + self, + input_modal, + input_ids=None, + modal_start_tokens=None, + modal_end_tokens=None, + attention_mask=None, + token_type_ids=None, + modal_token_type_ids=None, + position_ids=None, + modal_position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.mmbt( + input_modal=input_modal, + input_ids=input_ids, + modal_start_tokens=modal_start_tokens, + modal_end_tokens=modal_end_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + modal_token_type_ids=modal_token_type_ids, + position_ids=position_ids, + modal_position_ids=modal_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["MMBTForClassification", "MMBTModel", "ModalEmbeddings"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5373969ce7831491b0d5fa5495078fb1d3f6e4e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_nat import * + from .modeling_nat import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py new file mode 100644 index 0000000000000000000000000000000000000000..85961aa2fe8d0bc547b4919855ff8a28815a0fcd --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/configuration_nat.py @@ -0,0 +1,148 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Neighborhood Attention Transformer model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging +from ....utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class NatConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Nat + [shi-labs/nat-mini-in1k-224](https://huggingface.co/shi-labs/nat-mini-in1k-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 64): + Dimensionality of patch embedding. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`): + Number of layers in each level of the encoder. + num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`): + Number of attention heads in each layer of the Transformer encoder. + kernel_size (`int`, *optional*, defaults to 7): + Neighborhood Attention kernel size. + mlp_ratio (`float`, *optional*, defaults to 3.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + layer_scale_init_value (`float`, *optional*, defaults to 0.0): + The initial value for the layer scale. Disabled if <=0. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + + ```python + >>> from transformers import NatConfig, NatModel + + >>> # Initializing a Nat shi-labs/nat-mini-in1k-224 style configuration + >>> configuration = NatConfig() + + >>> # Initializing a model (with random weights) from the shi-labs/nat-mini-in1k-224 style configuration + >>> model = NatModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nat" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + patch_size=4, + num_channels=3, + embed_dim=64, + depths=[3, 4, 6, 5], + num_heads=[2, 4, 8, 16], + kernel_size=7, + mlp_ratio=3.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + initializer_range=0.02, + layer_norm_eps=1e-5, + layer_scale_init_value=0.0, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.kernel_size = kernel_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Nat work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.layer_scale_init_value = layer_scale_init_value + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["NatConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py b/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py new file mode 100644 index 0000000000000000000000000000000000000000..70ecffcf51ea9afa07baf52b32181d9a803b9871 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nat/modeling_nat.py @@ -0,0 +1,953 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Neighborhood Attention Transformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import BackboneOutput +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import ( + ModelOutput, + OptionalDependencyNotAvailable, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_natten_available, + logging, + replace_return_docstrings, + requires_backends, +) +from ....utils.backbone_utils import BackboneMixin +from .configuration_nat import NatConfig + + +if is_natten_available(): + from natten.functional import natten2dav, natten2dqkrpb +else: + + def natten2dqkrpb(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + def natten2dav(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "NatConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "shi-labs/nat-mini-in1k-224" +_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "shi-labs/nat-mini-in1k-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat" + + +# drop_path and NatDropPath are from the timm library. + + +@dataclass +class NatEncoderOutput(ModelOutput): + """ + Nat encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class NatModelOutput(ModelOutput): + """ + Nat model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class NatImageClassifierOutput(ModelOutput): + """ + Nat outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class NatEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = NatPatchEmbeddings(config) + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]: + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class NatPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + patch_size = config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + self.num_channels = num_channels + + if patch_size == 4: + pass + else: + # TODO: Support arbitrary patch sizes. + raise ValueError("Dinat only supports patch size of 4 at the moment.") + + self.projection = nn.Sequential( + nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + ) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + embeddings = embeddings.permute(0, 2, 3, 1) + + return embeddings + + +class NatDownsampler(nn.Module): + """ + Convolutional Downsampling Layer. + + Args: + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.dim = dim + self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, input_feature: torch.Tensor) -> torch.Tensor: + input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + input_feature = self.norm(input_feature) + return input_feature + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class NatDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class NeighborhoodAttention(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.kernel_size = kernel_size + + # rpb is learnable relative positional biases; same concept is used Swin. + self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1))) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 3, 1, 2, 4) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Apply the scale factor before computing attention weights. It's usually more efficient because + # attention weights are typically a bigger tensor compared to query. + # It gives identical results because scalars are commutable in matrix multiplication. + query_layer = query_layer / math.sqrt(self.attention_head_size) + + # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases. + attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, 1) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1) + context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class NeighborhoodAttentionOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class NeighborhoodAttentionModule(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size): + super().__init__() + self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size) + self.output = NeighborhoodAttentionOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class NatIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class NatOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class NatLayer(nn.Module): + def __init__(self, config, dim, num_heads, drop_path_rate=0.0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.kernel_size = config.kernel_size + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = NeighborhoodAttentionModule(config, dim, num_heads, kernel_size=self.kernel_size) + self.drop_path = NatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = NatIntermediate(config, dim) + self.output = NatOutput(config, dim) + self.layer_scale_parameters = ( + nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True) + if config.layer_scale_init_value > 0 + else None + ) + + def maybe_pad(self, hidden_states, height, width): + window_size = self.kernel_size + pad_values = (0, 0, 0, 0, 0, 0) + if height < window_size or width < window_size: + pad_l = pad_t = 0 + pad_r = max(0, window_size - width) + pad_b = max(0, window_size - height) + pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + # pad hidden_states if they are smaller than kernel size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + + attention_outputs = self.attention(hidden_states, output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_output = attention_output[:, :height, :width, :].contiguous() + + if self.layer_scale_parameters is not None: + attention_output = self.layer_scale_parameters[0] * attention_output + + hidden_states = shortcut + self.drop_path(attention_output) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.output(self.intermediate(layer_output)) + + if self.layer_scale_parameters is not None: + layer_output = self.layer_scale_parameters[1] * layer_output + + layer_output = hidden_states + self.drop_path(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class NatStage(nn.Module): + def __init__(self, config, dim, depth, num_heads, drop_path_rate, downsample): + super().__init__() + self.config = config + self.dim = dim + self.layers = nn.ModuleList( + [ + NatLayer( + config=config, + dim=dim, + num_heads=num_heads, + drop_path_rate=drop_path_rate[i], + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + _, height, width, _ = hidden_states.size() + for i, layer_module in enumerate(self.layers): + layer_outputs = layer_module(hidden_states, output_attentions) + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + hidden_states = self.downsample(hidden_states_before_downsampling) + + stage_outputs = (hidden_states, hidden_states_before_downsampling) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class NatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.num_levels = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.levels = nn.ModuleList( + [ + NatStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=NatDownsampler if (i_layer < self.num_levels - 1) else None, + ) + for i_layer in range(self.num_levels) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, NatEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.levels): + layer_outputs = layer_module(hidden_states, output_attentions) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + + if output_hidden_states and output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return NatEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +class NatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = NatConfig + base_model_prefix = "nat" + main_input_name = "pixel_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +NAT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`NatConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +NAT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Nat Model transformer outputting raw hidden-states without any specific head on top.", + NAT_START_DOCSTRING, +) +class NatModel(NatPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.config = config + self.num_levels = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1)) + + self.embeddings = NatEmbeddings(config) + self.encoder = NatEncoder(config) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NatModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, NatModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return NatModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Nat Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + NAT_START_DOCSTRING, +) +class NatForImageClassification(NatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.num_labels = config.num_labels + self.nat = NatModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.nat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=NatImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, NatImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nat( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return NatImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + "NAT backbone, to be used with frameworks like DETR and MaskFormer.", + NAT_START_DOCSTRING, +) +class NatBackbone(NatPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + requires_backends(self, ["natten"]) + + self.embeddings = NatEmbeddings(config) + self.encoder = NatEncoder(config) + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self.out_features, self.channels): + hidden_states_norms[stage] = nn.LayerNorm(num_channels) + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") + >>> model = AutoBackbone.from_pretrained( + ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 512, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + return_dict=True, + ) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + # TODO can we simplify this? + batch_size, num_channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state = hidden_state.view(batch_size, height * width, num_channels) + hidden_state = self.hidden_states_norms[stage](hidden_state) + hidden_state = hidden_state.view(batch_size, height, width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["NatForImageClassification", "NatModel", "NatPreTrainedModel", "NatBackbone"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0690129ae9edf84829278790b5e065bbd6608ee --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_nezha import * + from .modeling_nezha import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py new file mode 100644 index 0000000000000000000000000000000000000000..00d193cd1ae68655e1631b7caea2220987a29f0b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/configuration_nezha.py @@ -0,0 +1,105 @@ +from .... import PretrainedConfig + + +class NezhaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`NezhaModel`]. It is used to instantiate an Nezha + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Nezha + [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, optional, defaults to 21128): + Vocabulary size of the NEZHA model. Defines the different tokens that can be represented by the + *inputs_ids* passed to the forward method of [`NezhaModel`]. + hidden_size (`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, optional, defaults to 3072): + The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + hidden_dropout_prob (`float`, optional, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, optional, defaults to 2): + The vocabulary size of the *token_type_ids* passed into [`NezhaModel`]. + initializer_range (`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + classifier_dropout (`float`, optional, defaults to 0.1): + The dropout ratio for attached classifiers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + + Example: + + ```python + >>> from transformers import NezhaConfig, NezhaModel + + >>> # Initializing an Nezha configuration + >>> configuration = NezhaConfig() + + >>> # Initializing a model (with random weights) from the Nezha-base style configuration model + >>> model = NezhaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nezha" + + def __init__( + self, + vocab_size=21128, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + max_relative_position=64, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + classifier_dropout=0.1, + pad_token_id=0, + bos_token_id=2, + eos_token_id=3, + use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.max_relative_position = max_relative_position + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.classifier_dropout = classifier_dropout + self.use_cache = use_cache + + +__all__ = ["NezhaConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py b/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py new file mode 100644 index 0000000000000000000000000000000000000000..7be52bee5847cb9c82b2efb32e8729f198f9ae75 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/nezha/modeling_nezha.py @@ -0,0 +1,1697 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Nezha model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_nezha import NezhaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "sijunhe/nezha-cn-base" +_CONFIG_FOR_DOC = "NezhaConfig" + + +def load_tf_weights_in_nezha(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class NezhaRelativePositionsEncoding(nn.Module): + """Implement the Functional Relative Position Encoding""" + + def __init__(self, length, depth, max_relative_position=127): + super().__init__() + vocab_size = max_relative_position * 2 + 1 + range_vec = torch.arange(length) + range_mat = range_vec.repeat(length).view(length, length) + distance_mat = range_mat - torch.t(range_mat) + distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position) + final_mat = distance_mat_clipped + max_relative_position + + embeddings_table = torch.zeros(vocab_size, depth) + position = torch.arange(0, vocab_size, dtype=torch.int64).float().unsqueeze(1) + div_term = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth)) + embeddings_table[:, 0::2] = torch.sin(position * div_term) + embeddings_table[:, 1::2] = torch.cos(position * div_term) + + flat_relative_positions_matrix = final_mat.view(-1) + one_hot_relative_positions_matrix = torch.nn.functional.one_hot( + flat_relative_positions_matrix, num_classes=vocab_size + ).float() + positions_encoding = torch.matmul(one_hot_relative_positions_matrix, embeddings_table) + my_shape = list(final_mat.size()) + my_shape.append(depth) + positions_encoding = positions_encoding.view(my_shape) + self.register_buffer("positions_encoding", positions_encoding, persistent=False) + + def forward(self, length): + return self.positions_encoding[:length, :length, :] + + +class NezhaEmbeddings(nn.Module): + """Construct the embeddings from word and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.register_buffer( + "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class NezhaSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.relative_positions_encoding = NezhaRelativePositionsEncoding( + length=config.max_position_embeddings, + depth=self.attention_head_size, + max_relative_position=config.max_relative_position, + ) + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size() + relations_keys = self.relative_positions_encoding(to_seq_length) + query_layer_t = query_layer.permute(2, 0, 1, 3) + + query_layer_r = query_layer_t.contiguous().view( + from_seq_length, batch_size * num_attention_heads, self.attention_head_size + ) + key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1)) + key_position_scores_r = key_position_scores.view( + from_seq_length, batch_size, num_attention_heads, from_seq_length + ) + key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3) + attention_scores = attention_scores + key_position_scores_r_t + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in NezhaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + relations_values = self.relative_positions_encoding(to_seq_length) + attention_probs_t = attention_probs.permute(2, 0, 1, 3) + attentions_probs_r = attention_probs_t.contiguous().view( + from_seq_length, batch_size * num_attention_heads, to_seq_length + ) + value_position_scores = torch.matmul(attentions_probs_r, relations_values) + value_position_scores_r = value_position_scores.view( + from_seq_length, batch_size, num_attention_heads, self.attention_head_size + ) + value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3) + context_layer = context_layer + value_position_scores_r_t + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class NezhaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NezhaAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = NezhaSelfAttention(config) + self.output = NezhaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class NezhaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class NezhaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class NezhaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = NezhaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = NezhaAttention(config) + self.intermediate = NezhaIntermediate(config) + self.output = NezhaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class NezhaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([NezhaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class NezhaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class NezhaPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class NezhaLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = NezhaPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class NezhaOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = NezhaLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class NezhaOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class NezhaPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = NezhaLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class NezhaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = NezhaConfig + load_tf_weights = load_tf_weights_in_nezha + base_model_prefix = "nezha" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +class NezhaForPreTrainingOutput(ModelOutput): + """ + Output type of [`NezhaForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +NEZHA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`NezhaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NEZHA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Nezha Model transformer outputting raw hidden-states without any specific head on top.", + NEZHA_START_DOCSTRING, +) +class NezhaModel(NezhaPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = NezhaEmbeddings(config) + self.encoder = NezhaEncoder(config) + + self.pooler = NezhaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Nezha Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + NEZHA_START_DOCSTRING, +) +class NezhaForPreTraining(NezhaPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + self.nezha = NezhaModel(config) + self.cls = NezhaPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], NezhaForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, NezhaForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base") + >>> model = NezhaForPreTraining.from_pretrained("sijunhe/nezha-cn-base") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return NezhaForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING) +class NezhaForMaskedLM(NezhaPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `NezhaForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.nezha = NezhaModel(config, add_pooling_layer=False) + self.cls = NezhaOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Nezha Model with a `next sentence prediction (classification)` head on top.""", + NEZHA_START_DOCSTRING, +) +class NezhaForNextSentencePrediction(NezhaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.nezha = NezhaModel(config) + self.cls = NezhaOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, NezhaForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("sijunhe/nezha-cn-base") + >>> model = NezhaForNextSentencePrediction.from_pretrained("sijunhe/nezha-cn-base") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nezha Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + NEZHA_START_DOCSTRING, +) +class NezhaForSequenceClassification(NezhaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.nezha = NezhaModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nezha Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + NEZHA_START_DOCSTRING, +) +class NezhaForMultipleChoice(NezhaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.nezha = NezhaModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + print(pooled_output.shape) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + print(logits.shape) + print(num_choices) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nezha Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + NEZHA_START_DOCSTRING, +) +class NezhaForTokenClassification(NezhaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.nezha = NezhaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Nezha Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + NEZHA_START_DOCSTRING, +) +class NezhaForQuestionAnswering(NezhaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.nezha = NezhaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.nezha( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "NezhaForNextSentencePrediction", + "NezhaForMaskedLM", + "NezhaForPreTraining", + "NezhaForMultipleChoice", + "NezhaForQuestionAnswering", + "NezhaForSequenceClassification", + "NezhaForTokenClassification", + "NezhaModel", + "NezhaPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3964d194bed041987c6236b5c60bfcf3b7caf4 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_open_llama import * + from .modeling_open_llama import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bc9cc72a7661fe571cce6649b40f2307d9b3ce --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/configuration_open_llama.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Open-Llama model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class OpenLlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`OpenLlamaModel`]. It is used to instantiate an + Open-Llama model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [s-JoL/Open-Llama-V1](https://huggingface.co/s-JoL/Open-Llama-V1). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Open-Llama model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`OpenLlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + + Example: + + ```python + >>> from transformers import OpenLlamaModel, OpenLlamaConfig + + >>> # Initializing a Open-Llama open_llama-7b style configuration + >>> configuration = OpenLlamaConfig() + + >>> # Initializing a model from the open_llama-7b style configuration + >>> model = OpenLlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "open-llama" + + def __init__( + self, + vocab_size=100000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + use_memory_efficient_attention=True, + hidden_dropout_prob=0.1, + attention_dropout_prob=0.1, + use_stable_embedding=True, + shared_input_output_embedding=True, + rope_theta=10000.0, + rope_scaling=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.use_memory_efficient_attention = kwargs.pop( + "use_memorry_efficient_attention", use_memory_efficient_attention + ) + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_dropout_prob = attention_dropout_prob + self.use_stable_embedding = use_stable_embedding + self.shared_input_output_embedding = shared_input_output_embedding + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +__all__ = ["OpenLlamaConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..79d79ea546a950bd3e6deb81b93ff6ee7b5d60a4 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -0,0 +1,975 @@ +# coding=utf-8 +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Open-Llama model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ....modeling_utils import PreTrainedModel +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_open_llama import OpenLlamaConfig + + +logger = logging.get_logger(__name__) + +try: + from xformers import ops as xops +except ImportError: + xops = None + + +_CONFIG_FOR_DOC = "OpenLlamaConfig" + + +class OpenLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + OpenLlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class OpenLlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class OpenLlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + dropout_prob: float, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + self.dropout = nn.Dropout(dropout_prob) + + def forward(self, x): + out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.dropout(out) + + +class OpenLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: OpenLlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + self.dropout_prob = config.attention_dropout_prob + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = OpenLlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.config.use_memory_efficient_attention and xops is not None and self.training: + attn_weights = None + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = xops.memory_efficient_attention( + query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob + ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class OpenLlamaDecoderLayer(nn.Module): + def __init__(self, config: OpenLlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = OpenLlamaAttention(config=config) + self.mlp = OpenLlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + dropout_prob=config.hidden_dropout_prob, + ) + self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +OPEN_LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OpenLlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.", + OPEN_LLAMA_START_DOCSTRING, +) +class OpenLlamaPreTrainedModel(PreTrainedModel): + config_class = OpenLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OpenLlamaDecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + if self.config.use_stable_embedding: + torch.nn.init.xavier_normal_(module.weight.data) + else: + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OPEN_LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Open-Llama Model outputting raw hidden-states without any specific head on top.", + OPEN_LLAMA_START_DOCSTRING, +) +class OpenLlamaModel(OpenLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OpenLlamaDecoderLayer`] + + Args: + config: OpenLlamaConfig + """ + + def __init__(self, config: OpenLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + if config.use_stable_embedding: + self.embed_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.embed_layer_norm = None + self.layers = nn.ModuleList([OpenLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if self.embed_layer_norm: + inputs_embeds = self.embed_layer_norm(inputs_embeds) + # embed positions + if self.config.use_memory_efficient_attention and self.training: + attention_mask = None + elif attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + input_shape = (batch_size, seq_length) + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + None, + output_attentions, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = OpenLlamaModel(config) + if config.shared_input_output_embedding: + self.lm_head = None + else: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OpenLlamaForCausalLM + + >>> model = OpenLlamaForCausalLM.from_pretrained("openlm-research/open_llama_7b") + >>> tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.shared_input_output_embedding: + logits = torch.einsum( + "blh,vh->blv", hidden_states.to(self.model.embed_tokens.weight.device), self.model.embed_tokens.weight + ) + else: + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`OpenLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OPEN_LLAMA_START_DOCSTRING, +) +class OpenLlamaForSequenceClassification(OpenLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OpenLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(OPEN_LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = ["OpenLlamaPreTrainedModel", "OpenLlamaModel", "OpenLlamaForCausalLM", "OpenLlamaForSequenceClassification"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..864b321bc2ee3a521e3d6da5403cab7363dea56b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_qdqbert import * + from .modeling_qdqbert import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py new file mode 100644 index 0000000000000000000000000000000000000000..91ac82bc5a0292355ac659548c7c3d4336f2536c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/configuration_qdqbert.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""QDQBERT model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class QDQBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`QDQBertModel`]. It is used to instantiate an + QDQBERT model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the BERT + [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the QDQBERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`QDQBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`QDQBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + + Examples: + + ```python + >>> from transformers import QDQBertModel, QDQBertConfig + + >>> # Initializing a QDQBERT google-bert/bert-base-uncased style configuration + >>> configuration = QDQBertConfig() + + >>> # Initializing a model from the google-bert/bert-base-uncased style configuration + >>> model = QDQBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "qdqbert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + use_cache=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.use_cache = use_cache + + +__all__ = ["QDQBertConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py new file mode 100644 index 0000000000000000000000000000000000000000..8b68a4e426d445f5c6e3c472e3dc5ea54e661d57 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -0,0 +1,1749 @@ +# coding=utf-8 +# Copyright 2021 NVIDIA Corporation and The HuggingFace Team. +# Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch QDQBERT model.""" + +import math +import os +import warnings +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_pytorch_quantization_available, + logging, + replace_return_docstrings, + requires_backends, +) +from .configuration_qdqbert import QDQBertConfig + + +logger = logging.get_logger(__name__) + +# soft dependency +if is_pytorch_quantization_available(): + try: + from pytorch_quantization import nn as quant_nn + from pytorch_quantization.nn.modules.tensor_quantizer import TensorQuantizer + except OSError: + logger.error( + "QDQBERT model are not usable since `pytorch_quantization` can't be loaded. Please try to reinstall it" + " following the instructions here:" + " https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization." + ) + +_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased" +_CONFIG_FOR_DOC = "QDQBertConfig" + + +def load_tf_weights_in_qdqbert(model, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class QDQBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class QDQBertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + self.key = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + self.value = quant_nn.QuantLinear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + self.matmul_q_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_k_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_v_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.matmul_a_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul( + self.matmul_q_input_quantizer(query_layer), self.matmul_k_input_quantizer(key_layer.transpose(-1, -2)) + ) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in QDQBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul( + self.matmul_a_input_quantizer(attention_probs), self.matmul_v_input_quantizer(value_layer) + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class QDQBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + self.dense = quant_nn.QuantLinear(config.hidden_size, config.hidden_size) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Quantize the inputs to the residual add + self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Quantize the inputs to the residual add + add_local = self.add_local_input_quantizer(hidden_states) + add_residual = self.add_residual_input_quantizer(input_tensor) + hidden_states = self.LayerNorm(add_local + add_residual) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertAttention with Bert -> QDQBert +class QDQBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = QDQBertSelfAttention(config) + self.output = QDQBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class QDQBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + self.dense = quant_nn.QuantLinear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class QDQBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + # Quantize Linear layer + self.dense = quant_nn.QuantLinear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # Quantize the inputs to the residual add + self.add_local_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + self.add_residual_input_quantizer = TensorQuantizer(quant_nn.QuantLinear.default_quant_desc_input) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Quantize the inputs to the residual add + add_local = self.add_local_input_quantizer(hidden_states) + add_residual = self.add_residual_input_quantizer(input_tensor) + hidden_states = self.LayerNorm(add_local + add_residual) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert +class QDQBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_len_dim = 1 + self.attention = QDQBertAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = QDQBertAttention(config) + self.intermediate = QDQBertIntermediate(config) + self.output = QDQBertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = self.feed_forward_chunk(attention_output) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Based on transformers.models.bert.modeling_bert.BertEncoder with Bert -> QDQBert +class QDQBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([QDQBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class QDQBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class QDQBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert -> QDQBert +class QDQBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = QDQBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Based on transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert -> QDQBert +class QDQBertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = QDQBertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class QDQBertOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Based on transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert -> QDQBert +class QDQBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = QDQBertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +# Based on transformers.models.bert.modeling_bert.BertPreTrainedModel with Bert -> QDQBert +class QDQBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = QDQBertConfig + load_tf_weights = load_tf_weights_in_qdqbert + base_model_prefix = "bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +QDQBERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`QDQBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +QDQBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare QDQBERT Model transformer outputting raw hidden-states without any specific head on top.", + QDQBERT_START_DOCSTRING, +) +class QDQBertModel(QDQBertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer: bool = True): + requires_backends(self, "pytorch_quantization") + super().__init__(config) + self.config = config + + self.embeddings = QDQBertEmbeddings(config) + self.encoder = QDQBertEncoder(config) + + self.pooler = QDQBertPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING +) +class QDQBertLMHeadModel(QDQBertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `QDQBertLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.cls = QDQBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.LongTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, QDQBertLMHeadModel, QDQBertConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased") + >>> config = QDQBertConfig.from_pretrained("google-bert/bert-base-cased") + >>> config.is_decoder = True + >>> model = QDQBertLMHeadModel.from_pretrained("google-bert/bert-base-cased", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids: Optional[torch.LongTensor], + past_key_values=None, + attention_mask: Optional[torch.Tensor] = None, + **model_kwargs, + ): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING) +class QDQBertForMaskedLM(QDQBertPreTrainedModel): + _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `QDQBertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.cls = QDQBertOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, attention_mask: Optional[torch.FloatTensor] = None, **model_kwargs + ): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +@add_start_docstrings( + """Bert Model with a `next sentence prediction (classification)` head on top.""", + QDQBERT_START_DOCSTRING, +) +class QDQBertForNextSentencePrediction(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = QDQBertModel(config) + self.cls = QDQBertOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, QDQBertForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = QDQBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ```""" + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForSequenceClassification(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = QDQBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForMultipleChoice(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.bert = QDQBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + QDQBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForTokenClassification(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + QDQBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QDQBERT_START_DOCSTRING, +) +class QDQBertForQuestionAnswering(QDQBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = QDQBertModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "QDQBertForMaskedLM", + "QDQBertForMultipleChoice", + "QDQBertForNextSentencePrediction", + "QDQBertForQuestionAnswering", + "QDQBertForSequenceClassification", + "QDQBertForTokenClassification", + "QDQBertLayer", + "QDQBertLMHeadModel", + "QDQBertModel", + "QDQBertPreTrainedModel", + "load_tf_weights_in_qdqbert", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfdeb5d179c8e954c9c6822d3f154d632394964 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_realm import * + from .modeling_realm import * + from .retrieval_realm import * + from .tokenization_realm import * + from .tokenization_realm_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf32378a604beca939ff6e0241f69f4c4382120 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/configuration_realm.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""REALM model configuration.""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class RealmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of + + 1. [`RealmEmbedder`] + 2. [`RealmScorer`] + 3. [`RealmKnowledgeAugEncoder`] + 4. [`RealmRetriever`] + 5. [`RealmReader`] + 6. [`RealmForOpenQA`] + + It is used to instantiate an REALM model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the REALM + [google/realm-cc-news-pretrained-embedder](https://huggingface.co/google/realm-cc-news-pretrained-embedder) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the REALM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RealmEmbedder`], [`RealmScorer`], [`RealmKnowledgeAugEncoder`], or + [`RealmReader`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimension of the encoder layers and the pooler layer. + retriever_proj_size (`int`, *optional*, defaults to 128): + Dimension of the retriever(embedder) projection. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_candidates (`int`, *optional*, defaults to 8): + Number of candidates inputted to the RealmScorer or RealmKnowledgeAugEncoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RealmEmbedder`], [`RealmScorer`], + [`RealmKnowledgeAugEncoder`], or [`RealmReader`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + span_hidden_size (`int`, *optional*, defaults to 256): + Dimension of the reader's spans. + max_span_width (`int`, *optional*, defaults to 10): + Max span width of the reader. + reader_layer_norm_eps (`float`, *optional*, defaults to 1e-3): + The epsilon used by the reader's layer normalization layers. + reader_beam_size (`int`, *optional*, defaults to 5): + Beam size of the reader. + reader_seq_len (`int`, *optional*, defaults to 288+32): + Maximum sequence length of the reader. + num_block_records (`int`, *optional*, defaults to 13353718): + Number of block records. + searcher_beam_size (`int`, *optional*, defaults to 5000): + Beam size of the searcher. Note that when eval mode is enabled, *searcher_beam_size* will be the same as + *reader_beam_size*. + + Example: + + ```python + >>> from transformers import RealmConfig, RealmEmbedder + + >>> # Initializing a REALM realm-cc-news-pretrained-* style configuration + >>> configuration = RealmConfig() + + >>> # Initializing a model (with random weights) from the google/realm-cc-news-pretrained-embedder style configuration + >>> model = RealmEmbedder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "realm" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + retriever_proj_size=128, + num_hidden_layers=12, + num_attention_heads=12, + num_candidates=8, + intermediate_size=3072, + hidden_act="gelu_new", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + span_hidden_size=256, + max_span_width=10, + reader_layer_norm_eps=1e-3, + reader_beam_size=5, + reader_seq_len=320, # 288 + 32 + num_block_records=13353718, + searcher_beam_size=5000, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + # Common config + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.retriever_proj_size = retriever_proj_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_candidates = num_candidates + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + + # Reader config + self.span_hidden_size = span_hidden_size + self.max_span_width = max_span_width + self.reader_layer_norm_eps = reader_layer_norm_eps + self.reader_beam_size = reader_beam_size + self.reader_seq_len = reader_seq_len + + # Retrieval config + self.num_block_records = num_block_records + self.searcher_beam_size = searcher_beam_size + + +__all__ = ["RealmConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py new file mode 100644 index 0000000000000000000000000000000000000000..ac25a177333ef4eb5eb18f628adf1235d2084fee --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/modeling_realm.py @@ -0,0 +1,1862 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch REALM model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + ModelOutput, +) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_realm import RealmConfig + + +logger = logging.get_logger(__name__) +_EMBEDDER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-embedder" +_ENCODER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-encoder" +_SCORER_CHECKPOINT_FOR_DOC = "google/realm-cc-news-pretrained-scorer" +_CONFIG_FOR_DOC = "RealmConfig" + + +def load_tf_weights_in_realm(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + if isinstance(model, RealmReader) and "reader" not in name: + logger.info(f"Skipping {name} as it is not {model.__class__.__name__}'s parameter") + continue + + # For pretrained openqa reader + if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmForOpenQA): + name = name.replace("bert/", "reader/realm/") + name = name.replace("cls/", "reader/cls/") + + # For pretrained encoder + if (name.startswith("bert") or name.startswith("cls")) and isinstance(model, RealmKnowledgeAugEncoder): + name = name.replace("bert/", "realm/") + + # For finetuned reader + if name.startswith("reader"): + reader_prefix = "" if isinstance(model, RealmReader) else "reader/" + name = name.replace("reader/module/bert/", f"{reader_prefix}realm/") + name = name.replace("reader/module/cls/", f"{reader_prefix}cls/") + name = name.replace("reader/dense/", f"{reader_prefix}qa_outputs/dense_intermediate/") + name = name.replace("reader/dense_1/", f"{reader_prefix}qa_outputs/dense_output/") + name = name.replace("reader/layer_normalization", f"{reader_prefix}qa_outputs/layer_normalization") + + # For embedder and scorer + if name.startswith("module/module/module/"): # finetuned + embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/" + name = name.replace("module/module/module/module/bert/", f"{embedder_prefix}realm/") + name = name.replace("module/module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/") + name = name.replace("module/module/module/dense/", f"{embedder_prefix}cls/dense/") + name = name.replace("module/module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/") + name = name.replace("module/module/module/bert/", f"{embedder_prefix}realm/") + name = name.replace("module/module/module/cls/predictions/", f"{embedder_prefix}cls/predictions/") + elif name.startswith("module/module/"): # pretrained + embedder_prefix = "" if isinstance(model, RealmEmbedder) else "embedder/" + name = name.replace("module/module/LayerNorm/", f"{embedder_prefix}cls/LayerNorm/") + name = name.replace("module/module/dense/", f"{embedder_prefix}cls/dense/") + + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class RealmEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class RealmSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RealmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class RealmSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +REALM_SELF_ATTENTION_CLASSES = { + "eager": RealmSelfAttention, +} + + +class RealmAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = REALM_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = RealmSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class RealmIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class RealmOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class RealmLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RealmAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RealmAttention(config, position_embedding_type="absolute") + self.intermediate = RealmIntermediate(config) + self.output = RealmOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class RealmEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RealmLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class RealmPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@dataclass +class RealmEmbedderOutput(ModelOutput): + """ + Outputs of [`RealmEmbedder`] models. + + Args: + projected_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`): + + Projected score. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projected_score: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RealmScorerOutput(ModelOutput): + """ + Outputs of [`RealmScorer`] models. + + Args: + relevance_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates)`): + The relevance score of document candidates (before softmax). + query_score (`torch.FloatTensor` of shape `(batch_size, config.retriever_proj_size)`): + Query score derived from the query embedder. + candidate_score (`torch.FloatTensor` of shape `(batch_size, config.num_candidates, config.retriever_proj_size)`): + Candidate score derived from the embedder. + """ + + relevance_score: Optional[torch.FloatTensor] = None + query_score: Optional[torch.FloatTensor] = None + candidate_score: Optional[torch.FloatTensor] = None + + +@dataclass +class RealmReaderOutput(ModelOutput): + """ + Outputs of [`RealmReader`] models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided): + Total loss. + retriever_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided): + Retriever loss. + reader_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `start_positions`, `end_positions`, `has_answers` are provided): + Reader loss. + retriever_correct (`torch.BoolTensor` of shape `(config.searcher_beam_size,)`, *optional*): + Whether or not an evidence block contains answer. + reader_correct (`torch.BoolTensor` of shape `(config.reader_beam_size, num_candidates)`, *optional*): + Whether or not a span candidate contains answer. + block_idx (`torch.LongTensor` of shape `()`): + The index of the retrieved evidence block in which the predicted answer is most likely. + candidate (`torch.LongTensor` of shape `()`): + The index of the retrieved span candidates in which the predicted answer is most likely. + start_pos (`torch.IntTensor` of shape `()`): + Predicted answer starting position in *RealmReader*'s inputs. + end_pos (`torch.IntTensor` of shape `()`): + Predicted answer ending position in *RealmReader*'s inputs. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + retriever_loss: Optional[torch.FloatTensor] = None + reader_loss: Optional[torch.FloatTensor] = None + retriever_correct: torch.BoolTensor = None + reader_correct: torch.BoolTensor = None + block_idx: Optional[torch.LongTensor] = None + candidate: Optional[torch.LongTensor] = None + start_pos: torch.int32 = None + end_pos: torch.int32 = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class RealmForOpenQAOutput(ModelOutput): + """ + + Outputs of [`RealmForOpenQA`] models. + + Args: + reader_output (`dict`): + Reader output. + predicted_answer_ids (`torch.LongTensor` of shape `(answer_sequence_length)`): + Predicted answer ids. + """ + + reader_output: dict = None + predicted_answer_ids: Optional[torch.LongTensor] = None + + +class RealmPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class RealmLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = RealmPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class RealmOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = RealmLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class RealmScorerProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = RealmLMPredictionHead(config) + self.dense = nn.Linear(config.hidden_size, config.retriever_proj_size) + self.LayerNorm = nn.LayerNorm(config.retriever_proj_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class RealmReaderProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.dense_intermediate = nn.Linear(config.hidden_size, config.span_hidden_size * 2) + self.dense_output = nn.Linear(config.span_hidden_size, 1) + self.layer_normalization = nn.LayerNorm(config.span_hidden_size, eps=config.reader_layer_norm_eps) + self.relu = nn.ReLU() + + def forward(self, hidden_states, block_mask): + def span_candidates(masks): + """ + Generate span candidates. + + Args: + masks: [num_retrievals, max_sequence_len] + + Returns: + starts: [num_spans] ends: [num_spans] span_masks: [num_retrievals, num_spans] + whether spans locate in evidence block. + """ + _, max_sequence_len = masks.shape + + def _spans_given_width(width): + current_starts = torch.arange(max_sequence_len - width + 1, device=masks.device) + current_ends = torch.arange(width - 1, max_sequence_len, device=masks.device) + return current_starts, current_ends + + starts, ends = zip(*(_spans_given_width(w + 1) for w in range(self.config.max_span_width))) + + # [num_spans] + starts = torch.cat(starts, 0) + ends = torch.cat(ends, 0) + + # [num_retrievals, num_spans] + start_masks = torch.index_select(masks, dim=-1, index=starts) + end_masks = torch.index_select(masks, dim=-1, index=ends) + span_masks = start_masks * end_masks + + return starts, ends, span_masks + + def mask_to_score(mask, dtype=torch.float32): + return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min + + # [reader_beam_size, max_sequence_len, span_hidden_size * 2] + hidden_states = self.dense_intermediate(hidden_states) + # [reader_beam_size, max_sequence_len, span_hidden_size] + start_projection, end_projection = hidden_states.chunk(2, dim=-1) + + candidate_starts, candidate_ends, candidate_mask = span_candidates(block_mask) + + candidate_start_projections = torch.index_select(start_projection, dim=1, index=candidate_starts) + candidate_end_projections = torch.index_select(end_projection, dim=1, index=candidate_ends) + candidate_hidden = candidate_start_projections + candidate_end_projections + + # [reader_beam_size, num_candidates, span_hidden_size] + candidate_hidden = self.relu(candidate_hidden) + # [reader_beam_size, num_candidates, span_hidden_size] + candidate_hidden = self.layer_normalization(candidate_hidden) + # [reader_beam_size, num_candidates] + reader_logits = self.dense_output(candidate_hidden).squeeze(-1) + # [reader_beam_size, num_candidates] + reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype) + + return reader_logits, candidate_starts, candidate_ends + + +REALM_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`RealmConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +REALM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert *input_ids* indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class RealmPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RealmConfig + load_tf_weights = load_tf_weights_in_realm + base_model_prefix = "realm" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _flatten_inputs(self, *inputs): + """Flatten inputs' shape to (-1, input_shape[-1])""" + flattened_inputs = [] + for tensor in inputs: + if tensor is None: + flattened_inputs.append(None) + else: + input_shape = tensor.shape + if len(input_shape) > 2: + tensor = tensor.view((-1, input_shape[-1])) + flattened_inputs.append(tensor) + return flattened_inputs + + +class RealmBertModel(RealmPreTrainedModel): + """ + Same as the original BertModel but remove docstrings. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RealmEmbeddings(config) + self.encoder = RealmEncoder(config) + + self.pooler = RealmPooler(config) if add_pooling_layer else None + + # Weights initialization is mostly managed by other Realm models, + # but we also have them initialized here to keep a consistency. + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + "The embedder of REALM outputting projected score that will be used to calculate relevance score.", + REALM_START_DOCSTRING, +) +class RealmEmbedder(RealmPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.realm = RealmBertModel(self.config) + self.cls = RealmScorerProjection(self.config) + self.post_init() + + def get_input_embeddings(self): + return self.realm.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.realm.embeddings.word_embeddings = value + + @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=RealmEmbedderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RealmEmbedderOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RealmEmbedder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder") + >>> model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> projected_score = outputs.projected_score + ``` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + realm_outputs = self.realm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # [batch_size, hidden_size] + pooler_output = realm_outputs[1] + # [batch_size, retriever_proj_size] + projected_score = self.cls(pooler_output) + + if not return_dict: + return (projected_score,) + realm_outputs[2:4] + else: + return RealmEmbedderOutput( + projected_score=projected_score, + hidden_states=realm_outputs.hidden_states, + attentions=realm_outputs.attentions, + ) + + +@add_start_docstrings( + "The scorer of REALM outputting relevance scores representing the score of document candidates (before softmax).", + REALM_START_DOCSTRING, +) +class RealmScorer(RealmPreTrainedModel): + r""" + Args: + query_embedder ([`RealmEmbedder`]): + Embedder for input sequences. If not specified, it will use the same embedder as candidate sequences. + """ + + def __init__(self, config, query_embedder=None): + super().__init__(config) + + self.embedder = RealmEmbedder(self.config) + + self.query_embedder = query_embedder if query_embedder is not None else self.embedder + + self.post_init() + + @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=RealmScorerOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + candidate_input_ids: Optional[torch.LongTensor] = None, + candidate_attention_mask: Optional[torch.FloatTensor] = None, + candidate_token_type_ids: Optional[torch.LongTensor] = None, + candidate_inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RealmScorerOutput]: + r""" + candidate_input_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`): + Indices of candidate input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + candidate_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + candidate_token_type_ids (`torch.LongTensor` of shape `(batch_size, num_candidates, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + candidate_inputs_embeds (`torch.FloatTensor` of shape `(batch_size * num_candidates, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `candidate_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert *candidate_input_ids* indices + into associated vectors than the model's internal embedding lookup matrix. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, RealmScorer + + >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-scorer") + >>> model = RealmScorer.from_pretrained("google/realm-cc-news-pretrained-scorer", num_candidates=2) + + >>> # batch_size = 2, num_candidates = 2 + >>> input_texts = ["How are you?", "What is the item in the picture?"] + >>> candidates_texts = [["Hello world!", "Nice to meet you!"], ["A cute cat.", "An adorable dog."]] + + >>> inputs = tokenizer(input_texts, return_tensors="pt") + >>> candidates_inputs = tokenizer.batch_encode_candidates(candidates_texts, max_length=10, return_tensors="pt") + + >>> outputs = model( + ... **inputs, + ... candidate_input_ids=candidates_inputs.input_ids, + ... candidate_attention_mask=candidates_inputs.attention_mask, + ... candidate_token_type_ids=candidates_inputs.token_type_ids, + ... ) + >>> relevance_score = outputs.relevance_score + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or input_embeds.") + + if candidate_input_ids is None and candidate_inputs_embeds is None: + raise ValueError("You have to specify either candidate_input_ids or candidate_inputs_embeds.") + + query_outputs = self.query_embedder( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # [batch_size * num_candidates, candidate_seq_len] + (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs( + candidate_input_ids, candidate_attention_mask, candidate_token_type_ids + ) + + candidate_outputs = self.embedder( + flattened_input_ids, + attention_mask=flattened_attention_mask, + token_type_ids=flattened_token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=candidate_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # [batch_size, retriever_proj_size] + query_score = query_outputs[0] + # [batch_size * num_candidates, retriever_proj_size] + candidate_score = candidate_outputs[0] + # [batch_size, num_candidates, retriever_proj_size] + candidate_score = candidate_score.view(-1, self.config.num_candidates, self.config.retriever_proj_size) + # [batch_size, num_candidates] + relevance_score = torch.einsum("bd,bnd->bn", query_score, candidate_score) + + if not return_dict: + return relevance_score, query_score, candidate_score + + return RealmScorerOutput( + relevance_score=relevance_score, query_score=query_score, candidate_score=candidate_score + ) + + +@add_start_docstrings( + "The knowledge-augmented encoder of REALM outputting masked language model logits and marginal log-likelihood" + " loss.", + REALM_START_DOCSTRING, +) +class RealmKnowledgeAugEncoder(RealmPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder"] + + def __init__(self, config): + super().__init__(config) + self.realm = RealmBertModel(self.config) + self.cls = RealmOnlyMLMHead(self.config) + self.post_init() + + def get_input_embeddings(self): + return self.realm.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.realm.embeddings.word_embeddings = value + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward( + REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length") + ) + @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + relevance_score: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + mlm_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + relevance_score (`torch.FloatTensor` of shape `(batch_size, num_candidates)`, *optional*): + Relevance score derived from RealmScorer, must be specified if you want to compute the masked language + modeling loss. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + + mlm_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid calculating joint loss on certain positions. If not specified, the loss will not be masked. + Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, RealmKnowledgeAugEncoder + + >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder") + >>> model = RealmKnowledgeAugEncoder.from_pretrained( + ... "google/realm-cc-news-pretrained-encoder", num_candidates=2 + ... ) + + >>> # batch_size = 2, num_candidates = 2 + >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] + + >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and relevance_score is None: + raise ValueError( + "You have to specify `relevance_score` when `labels` is specified in order to compute loss." + ) + + (flattened_input_ids, flattened_attention_mask, flattened_token_type_ids) = self._flatten_inputs( + input_ids, attention_mask, token_type_ids + ) + + joint_outputs = self.realm( + flattened_input_ids, + attention_mask=flattened_attention_mask, + token_type_ids=flattened_token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # [batch_size * num_candidates, joint_seq_len, hidden_size] + joint_output = joint_outputs[0] + # [batch_size * num_candidates, joint_seq_len, vocab_size] + prediction_scores = self.cls(joint_output) + # [batch_size, num_candidates] + candidate_score = relevance_score + + masked_lm_loss = None + if labels is not None: + batch_size, seq_length = labels.size() + + if mlm_mask is None: + mlm_mask = torch.ones_like(labels, dtype=torch.float32) + else: + mlm_mask = mlm_mask.type(torch.float32) + + # Compute marginal log-likelihood + loss_fct = CrossEntropyLoss(reduction="none") # -100 index = padding token + + # [batch_size * num_candidates * joint_seq_len, vocab_size] + mlm_logits = prediction_scores.view(-1, self.config.vocab_size) + # [batch_size * num_candidates * joint_seq_len] + mlm_targets = labels.tile(1, self.config.num_candidates).view(-1) + # [batch_size, num_candidates, joint_seq_len] + masked_lm_log_prob = -loss_fct(mlm_logits, mlm_targets).view( + batch_size, self.config.num_candidates, seq_length + ) + # [batch_size, num_candidates, 1] + candidate_log_prob = candidate_score.log_softmax(-1).unsqueeze(-1) + # [batch_size, num_candidates, joint_seq_len] + joint_gold_log_prob = candidate_log_prob + masked_lm_log_prob + # [batch_size, joint_seq_len] + marginal_gold_log_probs = joint_gold_log_prob.logsumexp(1) + # [] + masked_lm_loss = -torch.nansum(torch.sum(marginal_gold_log_probs * mlm_mask) / torch.sum(mlm_mask)) + + if not return_dict: + output = (prediction_scores,) + joint_outputs[2:4] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=joint_outputs.hidden_states, + attentions=joint_outputs.attentions, + ) + + +@add_start_docstrings("The reader of REALM.", REALM_START_DOCSTRING) +class RealmReader(RealmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.realm = RealmBertModel(config) + self.cls = RealmOnlyMLMHead(config) + self.qa_outputs = RealmReaderProjection(config) + + self.post_init() + + @add_start_docstrings_to_model_forward(REALM_INPUTS_DOCSTRING.format("reader_beam_size, sequence_length")) + @replace_return_docstrings(output_type=RealmReaderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + relevance_score: Optional[torch.FloatTensor] = None, + block_mask: Optional[torch.BoolTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + has_answers: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RealmReaderOutput]: + r""" + relevance_score (`torch.FloatTensor` of shape `(searcher_beam_size,)`, *optional*): + Relevance score, which must be specified if you want to compute the logits and marginal log loss. + block_mask (`torch.BoolTensor` of shape `(searcher_beam_size, sequence_length)`, *optional*): + The mask of the evidence block, which must be specified if you want to compute the logits and marginal log + loss. + start_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(searcher_beam_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + has_answers (`torch.BoolTensor` of shape `(searcher_beam_size,)`, *optional*): + Whether or not the evidence block has answer(s). + + Returns: + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if relevance_score is None: + raise ValueError("You have to specify `relevance_score` to calculate logits and loss.") + if block_mask is None: + raise ValueError("You have to specify `block_mask` to separate question block and evidence block.") + if token_type_ids.size(1) < self.config.max_span_width: + raise ValueError("The input sequence length must be greater than or equal to config.max_span_width.") + outputs = self.realm( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # [reader_beam_size, joint_seq_len, hidden_size] + sequence_output = outputs[0] + + # [reader_beam_size, num_candidates], [num_candidates], [num_candidates] + reader_logits, candidate_starts, candidate_ends = self.qa_outputs( + sequence_output, block_mask[0 : self.config.reader_beam_size] + ) + # [searcher_beam_size, 1] + retriever_logits = torch.unsqueeze(relevance_score[0 : self.config.reader_beam_size], -1) + # [reader_beam_size, num_candidates] + reader_logits += retriever_logits + # [] + predicted_block_index = torch.argmax(torch.max(reader_logits, dim=1).values) + # [] + predicted_candidate = torch.argmax(torch.max(reader_logits, dim=0).values) + # [1] + predicted_start = torch.index_select(candidate_starts, dim=0, index=predicted_candidate) + # [1] + predicted_end = torch.index_select(candidate_ends, dim=0, index=predicted_candidate) + + total_loss = None + retriever_loss = None + reader_loss = None + retriever_correct = None + reader_correct = None + if start_positions is not None and end_positions is not None and has_answers is not None: + + def compute_correct_candidates(candidate_starts, candidate_ends, gold_starts, gold_ends): + """Compute correct span.""" + # [reader_beam_size, num_answers, num_candidates] + is_gold_start = torch.eq( + torch.unsqueeze(torch.unsqueeze(candidate_starts, 0), 0), torch.unsqueeze(gold_starts, -1) + ) + is_gold_end = torch.eq( + torch.unsqueeze(torch.unsqueeze(candidate_ends, 0), 0), torch.unsqueeze(gold_ends, -1) + ) + + # [reader_beam_size, num_candidates] + return torch.any(torch.logical_and(is_gold_start, is_gold_end), 1) + + def marginal_log_loss(logits, is_correct): + """Loss based on the negative marginal log-likelihood.""" + + def mask_to_score(mask, dtype=torch.float32): + return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min + + # [] + log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1) + log_denominator = torch.logsumexp(logits, dim=-1) + return log_denominator - log_numerator + + # sometimes the start/end positions are outside our model inputs, we ignore these terms + # `-1` is reserved for no answer. + ignored_index = sequence_output.size(1) + start_positions = start_positions.clamp(-1, ignored_index) + end_positions = end_positions.clamp(-1, ignored_index) + + retriever_correct = has_answers + any_retriever_correct = torch.any(retriever_correct) + + reader_correct = compute_correct_candidates( + candidate_starts=candidate_starts, + candidate_ends=candidate_ends, + gold_starts=start_positions[0 : self.config.reader_beam_size], + gold_ends=end_positions[0 : self.config.reader_beam_size], + ) + any_reader_correct = torch.any(reader_correct) + + retriever_loss = marginal_log_loss(relevance_score, retriever_correct) + reader_loss = marginal_log_loss(reader_logits.view(-1), reader_correct.view(-1)) + retriever_loss *= any_retriever_correct.type(torch.float32) + reader_loss *= any_reader_correct.type(torch.float32) + + total_loss = (retriever_loss + reader_loss).mean() + + if not return_dict: + output = (predicted_block_index, predicted_candidate, predicted_start, predicted_end) + outputs[2:] + return ( + ((total_loss, retriever_loss, reader_loss, retriever_correct, reader_correct) + output) + if total_loss is not None + else output + ) + + return RealmReaderOutput( + loss=total_loss, + retriever_loss=retriever_loss, + reader_loss=reader_loss, + retriever_correct=retriever_correct, + reader_correct=reader_correct, + block_idx=predicted_block_index, + candidate=predicted_candidate, + start_pos=predicted_start, + end_pos=predicted_end, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +REALM_FOR_OPEN_QA_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token (should not be used in this model by design). + + [What are token type IDs?](../glossary#token-type-ids) + answer_ids (`list` of shape `(num_answers, answer_length)`, *optional*): + Answer ids for computing the marginal log-likelihood loss. Indices should be in `[-1, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-1` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "`RealmForOpenQA` for end-to-end open domain question answering.", + REALM_START_DOCSTRING, +) +class RealmForOpenQA(RealmPreTrainedModel): + def __init__(self, config, retriever=None): + super().__init__(config) + self.embedder = RealmEmbedder(config) + self.reader = RealmReader(config) + self.register_buffer( + "block_emb", + torch.zeros(()).new_empty( + size=(config.num_block_records, config.retriever_proj_size), + dtype=torch.float32, + device=torch.device("cpu"), + ), + ) + self.retriever = retriever + + self.post_init() + + @property + def searcher_beam_size(self): + if self.training: + return self.config.searcher_beam_size + return self.config.reader_beam_size + + def block_embedding_to(self, device): + """Send `self.block_emb` to a specific device. + + Args: + device (`str` or `torch.device`): + The device to which `self.block_emb` will be sent. + """ + + self.block_emb = self.block_emb.to(device) + + @add_start_docstrings_to_model_forward(REALM_FOR_OPEN_QA_DOCSTRING.format("1, sequence_length")) + @replace_return_docstrings(output_type=RealmForOpenQAOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor], + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + answer_ids: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, RealmForOpenQAOutput]: + r""" + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import RealmForOpenQA, RealmRetriever, AutoTokenizer + + >>> retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa") + >>> tokenizer = AutoTokenizer.from_pretrained("google/realm-orqa-nq-openqa") + >>> model = RealmForOpenQA.from_pretrained("google/realm-orqa-nq-openqa", retriever=retriever) + + >>> question = "Who is the pioneer in modern computer science?" + >>> question_ids = tokenizer([question], return_tensors="pt") + >>> answer_ids = tokenizer( + ... ["alan mathison turing"], + ... add_special_tokens=False, + ... return_token_type_ids=False, + ... return_attention_mask=False, + ... ).input_ids + + >>> reader_output, predicted_answer_ids = model(**question_ids, answer_ids=answer_ids, return_dict=False) + >>> predicted_answer = tokenizer.decode(predicted_answer_ids) + >>> loss = reader_output.loss + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and input_ids.shape[0] != 1: + raise ValueError("The batch_size of the inputs must be 1.") + + question_outputs = self.embedder( + input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True + ) + # [1, projection_size] + question_projection = question_outputs[0] + + # CPU computation starts. + # [1, block_emb_size] + batch_scores = torch.einsum("BD,QD->QB", self.block_emb, question_projection.to(self.block_emb.device)) + # [1, searcher_beam_size] + _, retrieved_block_ids = torch.topk(batch_scores, k=self.searcher_beam_size, dim=-1) + # [searcher_beam_size] + retrieved_block_ids = retrieved_block_ids.squeeze() + # [searcher_beam_size, projection_size] + retrieved_block_emb = torch.index_select(self.block_emb, dim=0, index=retrieved_block_ids) + # CPU computation ends. + + # Retrieve possible answers + has_answers, start_pos, end_pos, concat_inputs = self.retriever( + retrieved_block_ids.cpu(), input_ids, answer_ids, max_length=self.config.reader_seq_len + ) + + concat_inputs = concat_inputs.to(self.reader.device) + block_mask = concat_inputs.special_tokens_mask.type(torch.bool).to(device=self.reader.device) + block_mask.logical_not_().logical_and_(concat_inputs.token_type_ids.type(torch.bool)) + + if has_answers is not None: + has_answers = torch.tensor(has_answers, dtype=torch.bool, device=self.reader.device) + start_pos = torch.tensor(start_pos, dtype=torch.long, device=self.reader.device) + end_pos = torch.tensor(end_pos, dtype=torch.long, device=self.reader.device) + + # [searcher_beam_size] + retrieved_logits = torch.einsum( + "D,BD->B", question_projection.squeeze(), retrieved_block_emb.to(self.reader.device) + ) + + reader_output = self.reader( + input_ids=concat_inputs.input_ids[0 : self.config.reader_beam_size], + attention_mask=concat_inputs.attention_mask[0 : self.config.reader_beam_size], + token_type_ids=concat_inputs.token_type_ids[0 : self.config.reader_beam_size], + relevance_score=retrieved_logits, + block_mask=block_mask, + has_answers=has_answers, + start_positions=start_pos, + end_positions=end_pos, + return_dict=True, + ) + + predicted_block = concat_inputs.input_ids[reader_output.block_idx] + predicted_answer_ids = predicted_block[reader_output.start_pos : reader_output.end_pos + 1] + + if not return_dict: + return reader_output, predicted_answer_ids + + return RealmForOpenQAOutput( + reader_output=reader_output, + predicted_answer_ids=predicted_answer_ids, + ) + + +__all__ = [ + "RealmEmbedder", + "RealmForOpenQA", + "RealmKnowledgeAugEncoder", + "RealmPreTrainedModel", + "RealmReader", + "RealmScorer", + "load_tf_weights_in_realm", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c084f1d2090975f8d6ed91665caf21e3e7829c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/retrieval_realm.py @@ -0,0 +1,176 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""REALM Retriever model implementation.""" + +import os +from typing import Optional, Union + +import numpy as np +from huggingface_hub import hf_hub_download + +from transformers import AutoTokenizer + +from ....utils import logging, strtobool + + +_REALM_BLOCK_RECORDS_FILENAME = "block_records.npy" + + +logger = logging.get_logger(__name__) + + +def convert_tfrecord_to_np(block_records_path: str, num_block_records: int) -> np.ndarray: + import tensorflow.compat.v1 as tf + + blocks_dataset = tf.data.TFRecordDataset(block_records_path, buffer_size=512 * 1024 * 1024) + blocks_dataset = blocks_dataset.batch(num_block_records, drop_remainder=True) + np_record = next(blocks_dataset.take(1).as_numpy_iterator()) + + return np_record + + +class ScaNNSearcher: + """Note that ScaNNSearcher cannot currently be used within the model. In future versions, it might however be included.""" + + def __init__( + self, + db, + num_neighbors, + dimensions_per_block=2, + num_leaves=1000, + num_leaves_to_search=100, + training_sample_size=100000, + ): + """Build scann searcher.""" + + from scann.scann_ops.py.scann_ops_pybind import builder as Builder + + builder = Builder(db=db, num_neighbors=num_neighbors, distance_measure="dot_product") + builder = builder.tree( + num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=training_sample_size + ) + builder = builder.score_ah(dimensions_per_block=dimensions_per_block) + + self.searcher = builder.build() + + def search_batched(self, question_projection): + retrieved_block_ids, _ = self.searcher.search_batched(question_projection.detach().cpu()) + return retrieved_block_ids.astype("int64") + + +class RealmRetriever: + """The retriever of REALM outputting the retrieved evidence block and whether the block has answers as well as answer + positions." + + Parameters: + block_records (`np.ndarray`): + A numpy array which cantains evidence texts. + tokenizer ([`RealmTokenizer`]): + The tokenizer to encode retrieved texts. + """ + + def __init__(self, block_records, tokenizer): + super().__init__() + self.block_records = block_records + self.tokenizer = tokenizer + + def __call__(self, retrieved_block_ids, question_input_ids, answer_ids, max_length=None, return_tensors="pt"): + retrieved_blocks = np.take(self.block_records, indices=retrieved_block_ids, axis=0) + + question = self.tokenizer.decode(question_input_ids[0], skip_special_tokens=True) + + text = [] + text_pair = [] + for retrieved_block in retrieved_blocks: + text.append(question) + text_pair.append(retrieved_block.decode()) + + concat_inputs = self.tokenizer( + text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=max_length + ) + concat_inputs_tensors = concat_inputs.convert_to_tensors(return_tensors) + + if answer_ids is not None: + return self.block_has_answer(concat_inputs, answer_ids) + (concat_inputs_tensors,) + else: + return (None, None, None, concat_inputs_tensors) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *init_inputs, **kwargs): + if os.path.isdir(pretrained_model_name_or_path): + block_records_path = os.path.join(pretrained_model_name_or_path, _REALM_BLOCK_RECORDS_FILENAME) + else: + block_records_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=_REALM_BLOCK_RECORDS_FILENAME, **kwargs + ) + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is " + "potentially malicious. It's recommended to never unpickle data that could have come from an " + "untrusted source, or that could have been tampered with. If you already verified the pickle " + "data and decided to use it, you can set the environment variable " + "`TRUST_REMOTE_CODE` to `True` to allow it." + ) + block_records = np.load(block_records_path, allow_pickle=True) + + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) + + return cls(block_records, tokenizer) + + def save_pretrained(self, save_directory): + # save block records + np.save(os.path.join(save_directory, _REALM_BLOCK_RECORDS_FILENAME), self.block_records) + # save tokenizer + self.tokenizer.save_pretrained(save_directory) + + def block_has_answer(self, concat_inputs, answer_ids): + """check if retrieved_blocks has answers.""" + has_answers = [] + start_pos = [] + end_pos = [] + max_answers = 0 + + for input_id in concat_inputs.input_ids: + input_id_list = input_id.tolist() + # Check answers between two [SEP] tokens + first_sep_idx = input_id_list.index(self.tokenizer.sep_token_id) + second_sep_idx = first_sep_idx + 1 + input_id_list[first_sep_idx + 1 :].index(self.tokenizer.sep_token_id) + + start_pos.append([]) + end_pos.append([]) + for answer in answer_ids: + for idx in range(first_sep_idx + 1, second_sep_idx): + if answer[0] == input_id_list[idx]: + if input_id_list[idx : idx + len(answer)] == answer: + start_pos[-1].append(idx) + end_pos[-1].append(idx + len(answer) - 1) + + if len(start_pos[-1]) == 0: + has_answers.append(False) + else: + has_answers.append(True) + if len(start_pos[-1]) > max_answers: + max_answers = len(start_pos[-1]) + + # Pad -1 to max_answers + for start_pos_, end_pos_ in zip(start_pos, end_pos): + if len(start_pos_) < max_answers: + padded = [-1] * (max_answers - len(start_pos_)) + start_pos_ += padded + end_pos_ += padded + return has_answers, start_pos, end_pos + + +__all__ = ["RealmRetriever"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py new file mode 100644 index 0000000000000000000000000000000000000000..70e69bc4bc2bce4024322950add44e5f58b04f1c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm.py @@ -0,0 +1,563 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for REALM.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ....tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ....tokenization_utils_base import BatchEncoding +from ....utils import PaddingStrategy, logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class RealmTokenizer(PreTrainedTokenizer): + r""" + Construct a REALM tokenizer. + + [`RealmTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting and + wordpiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = RealmTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def batch_encode_candidates(self, text, **kwargs): + r""" + Encode a batch of text or text pair. This method is similar to regular __call__ method but has the following + differences: + + 1. Handle additional num_candidate axis. (batch_size, num_candidates, text) + 2. Always pad the sequences to *max_length*. + 3. Must specify *max_length* in order to stack packs of candidates into a batch. + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + text (`List[List[str]]`): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + text_pair (`List[List[str]]`, *optional*): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + **kwargs: + Keyword arguments of the __call__ method. + + Returns: + [`BatchEncoding`]: Encoded text or text pair. + + Example: + + ```python + >>> from transformers import RealmTokenizer + + >>> # batch_size = 2, num_candidates = 2 + >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] + + >>> tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-encoder") + >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") + ```""" + + # Always using a fixed sequence length to encode in order to stack candidates into a batch. + kwargs["padding"] = PaddingStrategy.MAX_LENGTH + + batch_text = text + batch_text_pair = kwargs.pop("text_pair", None) + return_tensors = kwargs.pop("return_tensors", None) + + output_data = { + "input_ids": [], + "attention_mask": [], + "token_type_ids": [], + } + + for idx, candidate_text in enumerate(batch_text): + if batch_text_pair is not None: + candidate_text_pair = batch_text_pair[idx] + else: + candidate_text_pair = None + + encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs) + + encoded_input_ids = encoded_candidates.get("input_ids") + encoded_attention_mask = encoded_candidates.get("attention_mask") + encoded_token_type_ids = encoded_candidates.get("token_type_ids") + + if encoded_input_ids is not None: + output_data["input_ids"].append(encoded_input_ids) + if encoded_attention_mask is not None: + output_data["attention_mask"].append(encoded_attention_mask) + if encoded_token_type_ids is not None: + output_data["token_type_ids"].append(encoded_token_type_ids) + + output_data = {key: item for key, item in output_data.items() if len(item) != 0} + + return BatchEncoding(output_data, tensor_type=return_tensors) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A REALM sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see + WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if never_split is not None and text in never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["RealmTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..7c173227befd31e7f6e1fed26877dceb13ad4041 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/realm/tokenization_realm_fast.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2022 The REALM authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Tokenization classes for REALM.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ....tokenization_utils_base import BatchEncoding +from ....tokenization_utils_fast import PreTrainedTokenizerFast +from ....utils import PaddingStrategy, logging +from .tokenization_realm import RealmTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class RealmTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" REALM tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + [`RealmTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = RealmTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def batch_encode_candidates(self, text, **kwargs): + r""" + Encode a batch of text or text pair. This method is similar to regular __call__ method but has the following + differences: + + 1. Handle additional num_candidate axis. (batch_size, num_candidates, text) + 2. Always pad the sequences to *max_length*. + 3. Must specify *max_length* in order to stack packs of candidates into a batch. + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + text (`List[List[str]]`): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + text_pair (`List[List[str]]`, *optional*): + The batch of sequences to be encoded. Each sequence must be in this format: (batch_size, + num_candidates, text). + **kwargs: + Keyword arguments of the __call__ method. + + Returns: + [`BatchEncoding`]: Encoded text or text pair. + + Example: + + ```python + >>> from transformers import RealmTokenizerFast + + >>> # batch_size = 2, num_candidates = 2 + >>> text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]] + + >>> tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-encoder") + >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt") + ```""" + + # Always using a fixed sequence length to encode in order to stack candidates into a batch. + kwargs["padding"] = PaddingStrategy.MAX_LENGTH + + batch_text = text + batch_text_pair = kwargs.pop("text_pair", None) + return_tensors = kwargs.pop("return_tensors", None) + + output_data = { + "input_ids": [], + "attention_mask": [], + "token_type_ids": [], + } + + for idx, candidate_text in enumerate(batch_text): + if batch_text_pair is not None: + candidate_text_pair = batch_text_pair[idx] + else: + candidate_text_pair = None + + encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs) + + encoded_input_ids = encoded_candidates.get("input_ids") + encoded_attention_mask = encoded_candidates.get("attention_mask") + encoded_token_type_ids = encoded_candidates.get("token_type_ids") + + if encoded_input_ids is not None: + output_data["input_ids"].append(encoded_input_ids) + if encoded_attention_mask is not None: + output_data["attention_mask"].append(encoded_attention_mask) + if encoded_token_type_ids is not None: + output_data["token_type_ids"].append(encoded_token_type_ids) + + output_data = {key: item for key, item in output_data.items() if len(item) != 0} + + return BatchEncoding(output_data, tensor_type=return_tensors) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A REALM sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A REALM sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["RealmTokenizerFast"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a875576607430c041abab01eccaf468a6cc9272e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_retribert import * + from .modeling_retribert import * + from .tokenization_retribert import * + from .tokenization_retribert_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py new file mode 100644 index 0000000000000000000000000000000000000000..80d755a16961450cae783d834385e5e6873dc24e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/configuration_retribert.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RetriBERT model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class RetriBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`RetriBertModel`]. It is used to instantiate a + RetriBertModel model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the RetriBERT + [yjernite/retribert-base-uncased](https://huggingface.co/yjernite/retribert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the RetriBERT model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`RetriBertModel`] + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the *token_type_ids* passed into [`BertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + share_encoders (`bool`, *optional*, defaults to `True`): + Whether or not to use the same Bert-type encoder for the queries and document + projection_dim (`int`, *optional*, defaults to 128): + Final dimension of the query and document representation after projection + """ + + model_type = "retribert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=8, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + share_encoders=True, + projection_dim=128, + pad_token_id=0, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.share_encoders = share_encoders + self.projection_dim = projection_dim + + +__all__ = ["RetriBertConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py new file mode 100644 index 0000000000000000000000000000000000000000..bcae1c0239e6f9b214e6227bf279ad9fc7b8381a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/modeling_retribert.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +RetriBERT model +""" + +import math +from typing import Optional + +import torch +import torch.utils.checkpoint as checkpoint +from torch import nn + +from ....modeling_utils import PreTrainedModel +from ....utils import add_start_docstrings, logging +from ...bert.modeling_bert import BertModel +from .configuration_retribert import RetriBertConfig + + +logger = logging.get_logger(__name__) + + +# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # +class RetriBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RetriBertConfig + load_tf_weights = None + base_model_prefix = "retribert" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +RETRIBERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RetriBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """Bert Based model to embed queries or document for document retrieval.""", + RETRIBERT_START_DOCSTRING, +) +class RetriBertModel(RetriBertPreTrainedModel): + def __init__(self, config: RetriBertConfig) -> None: + super().__init__(config) + self.projection_dim = config.projection_dim + + self.bert_query = BertModel(config) + self.bert_doc = None if config.share_encoders else BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.project_query = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + self.project_doc = nn.Linear(config.hidden_size, config.projection_dim, bias=False) + + self.ce_loss = nn.CrossEntropyLoss(reduction="mean") + + # Initialize weights and apply final processing + self.post_init() + + def embed_sentences_checkpointed( + self, + input_ids, + attention_mask, + sent_encoder, + checkpoint_batch_size=-1, + ): + # reproduces BERT forward pass with checkpointing + if checkpoint_batch_size < 0 or input_ids.shape[0] < checkpoint_batch_size: + return sent_encoder(input_ids, attention_mask=attention_mask)[1] + else: + # prepare implicit variables + device = input_ids.device + input_shape = input_ids.size() + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + head_mask = [None] * sent_encoder.config.num_hidden_layers + extended_attention_mask: torch.Tensor = sent_encoder.get_extended_attention_mask( + attention_mask, input_shape + ) + + # define function for checkpointing + def partial_encode(*inputs): + encoder_outputs = sent_encoder.encoder( + inputs[0], + attention_mask=inputs[1], + head_mask=head_mask, + ) + sequence_output = encoder_outputs[0] + pooled_output = sent_encoder.pooler(sequence_output) + return pooled_output + + # run embedding layer on everything at once + embedding_output = sent_encoder.embeddings( + input_ids=input_ids, position_ids=None, token_type_ids=token_type_ids, inputs_embeds=None + ) + # run encoding and pooling on one mini-batch at a time + pooled_output_list = [] + for b in range(math.ceil(input_ids.shape[0] / checkpoint_batch_size)): + b_embedding_output = embedding_output[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + b_attention_mask = extended_attention_mask[b * checkpoint_batch_size : (b + 1) * checkpoint_batch_size] + pooled_output = checkpoint.checkpoint(partial_encode, b_embedding_output, b_attention_mask) + pooled_output_list.append(pooled_output) + return torch.cat(pooled_output_list, dim=0) + + def embed_questions( + self, + input_ids, + attention_mask=None, + checkpoint_batch_size=-1, + ): + q_reps = self.embed_sentences_checkpointed( + input_ids, + attention_mask, + self.bert_query, + checkpoint_batch_size, + ) + return self.project_query(q_reps) + + def embed_answers( + self, + input_ids, + attention_mask=None, + checkpoint_batch_size=-1, + ): + a_reps = self.embed_sentences_checkpointed( + input_ids, + attention_mask, + self.bert_query if self.bert_doc is None else self.bert_doc, + checkpoint_batch_size, + ) + return self.project_doc(a_reps) + + def forward( + self, + input_ids_query: torch.LongTensor, + attention_mask_query: Optional[torch.FloatTensor], + input_ids_doc: torch.LongTensor, + attention_mask_doc: Optional[torch.FloatTensor], + checkpoint_batch_size: int = -1, + ) -> torch.FloatTensor: + r""" + Args: + input_ids_query (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the queries in a batch. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask_query (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + input_ids_doc (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary for the documents in a batch. + attention_mask_doc (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on documents padding token indices. + checkpoint_batch_size (`int`, *optional*, defaults to `-1`): + If greater than 0, uses gradient checkpointing to only compute sequence representation on + `checkpoint_batch_size` examples at a time on the GPU. All query representations are still compared to + all document representations in the batch. + + Return: + `torch.FloatTensor``: The bidirectional cross-entropy loss obtained while trying to match each query to its + corresponding document and each document to its corresponding query in the batch + """ + device = input_ids_query.device + q_reps = self.embed_questions(input_ids_query, attention_mask_query, checkpoint_batch_size) + a_reps = self.embed_answers(input_ids_doc, attention_mask_doc, checkpoint_batch_size) + compare_scores = torch.mm(q_reps, a_reps.t()) + loss_qa = self.ce_loss(compare_scores, torch.arange(compare_scores.shape[1]).to(device)) + loss_aq = self.ce_loss(compare_scores.t(), torch.arange(compare_scores.shape[0]).to(device)) + loss = (loss_qa + loss_aq) / 2 + return loss + + +__all__ = ["RetriBertModel", "RetriBertPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py new file mode 100644 index 0000000000000000000000000000000000000000..35a1874aa0e3c8c39e3d96c20a31d3b0595235f1 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert.py @@ -0,0 +1,504 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RetriBERT.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ....tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ....utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class RetriBertTokenizer(PreTrainedTokenizer): + r""" + Constructs a RetriBERT tokenizer. + + [`RetriBertTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation splitting + and wordpiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer + to: this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["RetriBertTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..cc51d0e2a1de9280e6b8635804eaf46f7859e5b9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/retribert/tokenization_retribert_fast.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for RetriBERT.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ....tokenization_utils_fast import PreTrainedTokenizerFast +from ....utils import logging +from .tokenization_retribert import RetriBertTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class RetriBertTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" RetriBERT tokenizer (backed by HuggingFace's *tokenizers* library). + + [`RetriBertTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = RetriBertTokenizer + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["RetriBertTokenizerFast"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78c549b6e294e1ca97a718ef601472d4d400e12a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_speech_to_text_2 import * + from .modeling_speech_to_text_2 import * + from .processing_speech_to_text_2 import * + from .tokenization_speech_to_text_2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py new file mode 100644 index 0000000000000000000000000000000000000000..2afd79feb28dd99c6ca94d2483b8f88b8cdc0a0e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/configuration_speech_to_text_2.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Speech2Text model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class Speech2Text2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Speech2Text2ForCausalLM`]. It is used to + instantiate an Speech2Text2 model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Speech2Text2 + [facebook/s2t-wav2vec2-large-en-de](https://huggingface.co/facebook/s2t-wav2vec2-large-en-de) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the Speech2Text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`Speech2TextModel`] + d_model (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the pooler. If string, `"gelu"`, `"relu"`, + `"silu"` and `"gelu_new"` are supported. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for activations inside the fully connected layer. + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + https://arxiv.org/abs/1909.11556>`__ for more details. + decoder_layerdrop (`float`, *optional*, defaults to 0.0): + The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) + for more details. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + max_target_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + + Example: + + ```python + >>> from transformers import Speech2Text2Config, Speech2Text2ForCausalLM + + >>> # Initializing a Speech2Text2 s2t_transformer_s style configuration + >>> configuration = Speech2Text2Config() + + >>> # Initializing a model (with random weights) from the s2t_transformer_s style configuration + >>> model = Speech2Text2ForCausalLM(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "speech_to_text_2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_attention_heads": "decoder_attention_heads", "hidden_size": "d_model"} + + def __init__( + self, + vocab_size=10000, + decoder_layers=6, + decoder_ffn_dim=2048, + decoder_attention_heads=4, + decoder_layerdrop=0.0, + use_cache=True, + activation_function="relu", + d_model=256, + dropout=0.1, + attention_dropout=0.0, + activation_dropout=0.0, + init_std=0.02, + decoder_start_token_id=2, + scale_embedding=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + max_target_positions=1024, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.decoder_attention_heads = decoder_attention_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.activation_function = activation_function + self.init_std = init_std + self.decoder_layerdrop = decoder_layerdrop + self.use_cache = use_cache + self.num_hidden_layers = decoder_layers + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_target_positions = max_target_positions + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + +__all__ = ["Speech2Text2Config"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1dd18d97ff61015759b99587b846f4825c427c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -0,0 +1,930 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Speech2Text2 model.""" + +import copy +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ....activations import ACT2FN +from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ....modeling_utils import PreTrainedModel +from ....utils import add_start_docstrings, logging, replace_return_docstrings +from .configuration_speech_to_text_2 import Speech2Text2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Speech2Text2Config" +_CHECKPOINT_FOR_DOC = "facebook/s2t-wav2vec2-large-en-de" + + +class Speech2Text2SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length.""" + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__() + self.offset = 2 + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + + def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx) + if hasattr(self, "weights"): + # in forward put the weights on the correct dtype and device of the param + emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device) + + self.weights = nn.Parameter(emb_weights) + self.weights.requires_grad = False + self.weights.detach_() + + @staticmethod + def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None): + """ + Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the + description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb.to(torch.get_default_dtype()) + + @torch.no_grad() + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + bsz, seq_len = input_ids.size() + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to( + input_ids.device + ) + + # expand embeddings if needed + max_pos = self.padding_idx + 1 + seq_len + if max_pos > self.weights.size(0): + self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx) + + return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach() + + def create_position_ids_from_input_ids( + self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0 + ): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +class Speech2Text2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[Speech2Text2Config] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class Speech2Text2DecoderLayer(nn.Module): + def __init__(self, config: Speech2Text2Config): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = Speech2Text2Attention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + if config.is_decoder: + self.encoder_attn = Speech2Text2Attention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size *(decoder_attention_heads,)*. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Speech2Text2PreTrainedModel(PreTrainedModel): + config_class = Speech2Text2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding): + weight = module.get_embedding(*module.weight.shape, module.padding_idx) + weight = nn.Parameter(weight, requires_grad=False) + weight.detach_() + module.weight = weight + + +SPEECH_TO_TEXT_2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Speech2Text2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class Speech2Text2Decoder(Speech2Text2PreTrainedModel): + """ + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Speech2Text2DecoderLayer`] + + Args: + config: Speech2Text2Config + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: Speech2Text2Config): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_target_positions + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + + self.embed_positions = Speech2Text2SinusoidalPositionalEmbedding( + self.max_target_positions, + config.d_model, + self.padding_idx, + ) + + self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`Speech2Text2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + of the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention + on hidden heads. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + + hidden_states = inputs_embeds + positions + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The Speech2Text2 Model with a language modeling head. Can be used for summarization.", + SPEECH_TO_TEXT_2_START_DOCSTRING, +) +class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = Speech2Text2Decoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +@add_start_docstrings( + "The Speech2Text2 Decoder with a language modeling head. Can be used as the decoder part of" + " [`EncoderDecoderModel`] and [`SpeechEncoderDecoder`].", + SPEECH_TO_TEXT_2_START_DOCSTRING, +) +class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.model = Speech2Text2DecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], CausalLMOutputWithCrossAttentions]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`Speech2Text2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention + if the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used + in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import ( + ... SpeechEncoderDecoderModel, + ... Speech2Text2ForCausalLM, + ... Wav2Vec2Model, + ... Speech2Text2Config, + ... Wav2Vec2Config, + ... Wav2Vec2FeatureExtractor, + ... Speech2Text2Tokenizer, + ... ) + >>> from datasets import load_dataset + + >>> feature_extractor = Wav2Vec2FeatureExtractor() + >>> tokenizer = Speech2Text2Tokenizer.from_pretrained("facebook/s2t-wav2vec2-large-en-de") + + >>> encoder = Wav2Vec2Model(Wav2Vec2Config()) + >>> decoder = Speech2Text2ForCausalLM(Speech2Text2Config()) + >>> # init random speech2text model + + >>> model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder) + >>> model.config.pad_token_id = tokenizer.pad_token_id + >>> model.config.decoder_start_token_id = tokenizer.bos_token_id + >>> # pre-process inputs and labels + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = feature_extractor( + ... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt" + ... ) + >>> input_values = inputs.input_values + >>> decoder_input_ids = tokenizer(ds[0]["text"], return_tensors="pt").input_ids + >>> # compute loss + + >>> loss = model(inputs=input_values, labels=decoder_input_ids).loss + >>> # backprop loss + + >>> loss.backward() # doctest: +IGNORE_RESULT + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = ["Speech2Text2ForCausalLM", "Speech2Text2PreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8edaa46d10b8ca04f68d530733fcb7cb390ee8 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/processing_speech_to_text_2.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Speech processor class for Speech2Text2 +""" + +import warnings +from contextlib import contextmanager + +from ....processing_utils import ProcessorMixin + + +class Speech2Text2Processor(ProcessorMixin): + r""" + Constructs a Speech2Text2 processor which wraps a Speech2Text2 feature extractor and a Speech2Text2 tokenizer into + a single processor. + + [`Speech2Text2Processor`] offers all the functionalities of [`AutoFeatureExtractor`] and [`Speech2Text2Tokenizer`]. + See the [`~Speech2Text2Processor.__call__`] and [`~Speech2Text2Processor.decode`] for more information. + + Args: + feature_extractor (`AutoFeatureExtractor`): + An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input. + tokenizer (`Speech2Text2Tokenizer`): + An instance of [`Speech2Text2Tokenizer`]. The tokenizer is a required input. + """ + + feature_extractor_class = "AutoFeatureExtractor" + tokenizer_class = "Speech2Text2Tokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + def __call__(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to AutoFeatureExtractor's + [`~AutoFeatureExtractor.__call__`] and returns its output. If used in the context + [`~Speech2Text2Processor.as_target_processor`] this method forwards all its arguments to + Speech2Text2Tokenizer's [`~Speech2Text2Tokenizer.__call__`]. Please refer to the docstring of the above two + methods for more information. + """ + # For backward compatibility + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + if "raw_speech" in kwargs: + warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.") + audio = kwargs.pop("raw_speech") + else: + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if audio is None and text is None: + raise ValueError("You need to specify either an `audio` or `text` input to process.") + + if audio is not None: + inputs = self.feature_extractor(audio, *args, sampling_rate=sampling_rate, **kwargs) + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Speech2Text2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning + Speech2Text2. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your audio inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + + +__all__ = ["Speech2Text2Processor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py new file mode 100644 index 0000000000000000000000000000000000000000..f5aa7ef8067c57d2fdae3b10c9ad3b01c3e23299 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/speech_to_text_2/tokenization_speech_to_text_2.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for Speech2Text2.""" + +import json +import os +from typing import Dict, List, Optional, Tuple + +from ....tokenization_utils import PreTrainedTokenizer +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +VOCAB_FILES_NAMES = { + "vocab_file": "vocab.json", + "tokenizer_config_file": "tokenizer_config.json", + "merges_file": "merges.txt", +} + + +BPE_TOKEN_MERGES = "" +BPE_TOKEN_VOCAB = "@@ " + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length + strings) + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +# Speech2Text2 has no max input length + + +class Speech2Text2Tokenizer(PreTrainedTokenizer): + """ + Constructs a Speech2Text2Tokenizer. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to + the superclass for more information regarding such methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sentence token. + eos_token (`str`, *optional*, defaults to `""`): + The end of sentence token. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + + **kwargs + Additional keyword arguments passed along to [`PreTrainedTokenizer`] + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + pad_token="", + eos_token="", + unk_token="", + do_lower_case=False, + merges_file=None, + **kwargs, + ): + self.do_lower_case = do_lower_case + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + + if merges_file is None: + logger.info(f"No merges files provided. {self.__class__.__name__} can only be used for decoding.") + + self.bpe_ranks = None + self.cache = None + else: + with open(merges_file, encoding="utf-8") as merges_handle: + merges = merges_handle.read().split("\n")[:-1] + + merges = [tuple(merge.split()[:2]) for merge in merges] + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {} + super().__init__( + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + do_lower_case=do_lower_case, + **kwargs, + ) + + @property + def vocab_size(self) -> int: + return len(self.decoder) + + def get_vocab(self) -> Dict: + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + word = tuple(token[:-1]) + (token[-1] + BPE_TOKEN_MERGES,) + if token in self.cache: + return self.cache[token] + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + if word == "\n " + BPE_TOKEN_MERGES: + word = "\n" + BPE_TOKEN_MERGES + + if word.endswith(BPE_TOKEN_MERGES): + word = word.replace(BPE_TOKEN_MERGES, "") + + word = word.replace(" ", BPE_TOKEN_VOCAB) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + + if self.bpe_ranks is None: + raise ValueError( + "This tokenizer was instantiated without a `merges.txt` file, so" + " that it can only be used for decoding, not for encoding. " + "Make sure to provide `merges.txt` file at instantiation to enable " + "encoding." + ) + + if self.do_lower_case: + text = text.lower() + + text = text.split() + + split_tokens = [] + for token in text: + if token: + split_tokens.extend(list(self.bpe(token).split(" "))) + + return split_tokens + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) in an index (integer) using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) in a token (str) using the vocab.""" + result = self.decoder.get(index, self.unk_token) + return result + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ + Converts a list of output tokens into a single string. + """ + # combine tokens + string = " ".join(tokens) + + # make sure @@ tokens are concatenated + string = "".join(string.split(BPE_TOKEN_VOCAB)) + + return string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merges_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + if self.bpe_ranks is None: + return (vocab_file,) + + with open(merges_file, "w", encoding="utf-8") as writer: + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merges_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return (vocab_file, merges_file) + + +__all__ = ["Speech2Text2Tokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b535eb1df2c40a7680bbf8d57fc70b78e23437f4 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tapex/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .tokenization_tapex import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py b/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py new file mode 100644 index 0000000000000000000000000000000000000000..3d554872c48df0c4f8ae9f4f63e98affca511688 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tapex/tokenization_tapex.py @@ -0,0 +1,1470 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for TAPEX.""" + +import json +import os +import random +from functools import lru_cache +from typing import Dict, List, Optional, Tuple, Union + +import regex as re + +from ....file_utils import ExplicitEnum, PaddingStrategy, TensorType, add_end_docstrings, is_pandas_available +from ....tokenization_utils import AddedToken, PreTrainedTokenizer +from ....tokenization_utils_base import ENCODE_KWARGS_DOCSTRING, BatchEncoding, TextInput, TruncationStrategy +from ....utils import logging + + +if is_pandas_available(): + import pandas as pd + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt"} + + +class TapexTruncationStrategy(ExplicitEnum): + """ + Possible values for the `truncation` argument in [`~TapasTokenizer.__call__`]. Useful for tab-completion in an IDE. + """ + + DROP_ROWS_TO_FIT = "drop_rows_to_fit" + + +TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str`, [`TapexTruncationStrategy`] or [`~tokenization_utils_base.TruncationStrategy`], + *optional*, defaults to `False`): + + Activates and controls truncation. Accepts the following values: + + - `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + row by row, removing rows from the table. + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to + `None`, this will use the predefined model maximum length if a maximum length is required by one of the + truncation/padding parameters. If the model has no specific maximum input length (like XLNet) + truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. +""" + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control + characters the bpe code barfs on. The reversible bpe codes work on unicode strings. This means you need a large # + of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset + you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe + vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """ + Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length + strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class IndexedRowTableLinearize: + """ + FORMAT: col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ... + """ + + def process_table(self, table_content: Dict): + """ + Given a table, TableLinearize aims at converting it into a flatten sequence with special symbols. + """ + assert "header" in table_content and "rows" in table_content, self.PROMPT_MESSAGE + # process header + table_str = self.process_header(table_content["header"]) + " " + # process rows + for i, row_example in enumerate(table_content["rows"]): + # NOTE: the row should start from row 1 instead of 0 + table_str += self.process_row(row_example, row_index=i + 1) + " " + return table_str.strip() + + def process_header(self, headers: List): + """ + Given a list of headers, TableLinearize aims at converting it into a flatten sequence with special symbols. + """ + return "col : " + " | ".join(headers) + + def process_row(self, row: List, row_index: int): + """ + Given a row, TableLinearize aims at converting it into a flatten sequence with special symbols. + """ + row_str = "" + row_cell_values = [] + for cell_value in row: + if isinstance(cell_value, int): + row_cell_values.append(str(cell_value)) + else: + row_cell_values.append(cell_value) + row_str += " | ".join(row_cell_values) + return "row " + str(row_index) + " : " + row_str + + +class TapexTokenizer(PreTrainedTokenizer): + r""" + Construct a TAPEX tokenizer. Based on byte-level Byte-Pair-Encoding (BPE). + + This tokenizer can be used to flatten one or more table(s) and concatenate them with one or more related sentences + to be used by TAPEX models. The format that the TAPEX tokenizer creates is the following: + + sentence col: col1 | col2 | col 3 row 1 : val1 | val2 | val3 row 2 : ... + + The tokenizer supports a single table + single query, a single table and multiple queries (in which case the table + will be duplicated for every query), a single query and multiple tables (in which case the query will be duplicated + for every table), and multiple tables and queries. In other words, you can provide a batch of tables + questions to + the tokenizer for instance to prepare them for the model. + + Tokenization itself is based on the BPE algorithm. It is identical to the one used by BART, RoBERTa and GPT-2. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. See + [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + add_prefix_space (`bool`, *optional*, defaults to `False`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. (BART tokenizer detect beginning of words by the preceding space). + max_cell_length (`int`, *optional*, defaults to 15): + Maximum number of characters per cell when linearizing a table. If this number is exceeded, truncation + takes place. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + do_lower_case=True, + errors="replace", + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + add_prefix_space=False, + max_cell_length=15, + **kwargs, + ): + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token + cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token + + with open(vocab_file, encoding="utf-8") as vocab_handle: + self.encoder = json.load(vocab_handle) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + with open(merges_file, encoding="utf-8") as merges_handle: + bpe_merges = merges_handle.read().split("\n")[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_merges] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.add_prefix_space = add_prefix_space + self.do_lower_case = do_lower_case + + # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + # additional properties + + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + do_lower_case=do_lower_case, + errors=errors, + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + add_prefix_space=add_prefix_space, + max_cell_length=max_cell_length, + **kwargs, + ) + + self.max_cell_length = max_cell_length + self.table_linearize = IndexedRowTableLinearize() + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A TAPEX sequence has the following format: + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Args: + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Args: + Create a mask from the two sequences passed to be used in a sequence-pair classification task. TAPEX does not: + make use of token type ids, therefore a list of zeros is returned. + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + Returns: + `List[int]`: List of zeros. + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) + if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): + text = " " + text + return (text, kwargs) + + @property + def vocab_size(self): + return len(self.encoder) + + def get_vocab(self): + return dict(self.encoder, **self.added_tokens_encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + except ValueError: + new_word.extend(word[i:]) + break + else: + new_word.extend(word[i:j]) + i = j + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def _tokenize(self, text): + """Tokenize a string.""" + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join( + self.byte_encoder[b] for b in token.encode("utf-8") + ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.decoder.get(index) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + text = "".join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + merge_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!" + ) + index = token_index + writer.write(" ".join(bpe_tokens) + "\n") + index += 1 + + return vocab_file, merge_file + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def __call__( + self, + table: Union["pd.DataFrame", List["pd.DataFrame"]] = None, + query: Optional[Union[TextInput, List[TextInput]]] = None, + answer: Union[str, List[str]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several table-sequence pair(s). + + Args: + table (`pd.DataFrame`, `List[pd.DataFrame]`): + Table(s) containing tabular data. + query (`str` or `List[str]`, *optional*): + Sentence or batch of sentences related to one or more table(s) to be encoded. Note that the number of + sentences must match the number of tables. + answer (`str` or `List[str]`, *optional*): + Optionally, the corresponding answer to the questions as supervision. + """ + + if table is not None: + return self.source_call_func( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + elif answer is not None: + return self.target_call_func( + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + raise ValueError("You need to provide either a `table` or an `answer`.") + + def source_call_func( + self, + table: Union["pd.DataFrame", List["pd.DataFrame"]], + query: Optional[Union[TextInput, List[TextInput]]] = None, + answer: Union[str, List[str]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Input type checking for clearer error + valid_table = False + valid_query = False + + # Check that table have a valid type + if isinstance(table, pd.DataFrame): + valid_table = True + elif isinstance(table, (list, tuple)) and isinstance(table[0], pd.DataFrame): + valid_table = True + + # Check that query have a valid type + if query is None or isinstance(query, str): + valid_query = True + elif isinstance(query, (list, tuple)): + if len(query) == 0 or isinstance(query[0], str): + valid_query = True + + if not valid_table: + raise ValueError( + "table input must of type `pd.DataFrame` (single example), `List[pd.DataFrame]` (batch of examples). " + ) + if not valid_query: + raise ValueError("query input must of type `str` (single example), `List[str]` (batch of examples). ") + is_batched = isinstance(table, (list, tuple)) or isinstance(query, (list, tuple)) + + if is_batched: + return self.batch_encode_plus( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def batch_encode_plus( + self, + table: Union["pd.DataFrame", List["pd.DataFrame"]], + query: Optional[List[TextInput]] = None, + answer: List[str] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + + + This method is deprecated, `__call__` should be used instead. + + + """ + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + table: Union["pd.DataFrame", List["pd.DataFrame"]], + query: Optional[List[TextInput]] = None, + answer: Optional[List[str]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + if isinstance(table, pd.DataFrame) and isinstance(query, (list, tuple)): + # single table, many queries case + # duplicate table for every query + table = [table] * len(query) + if isinstance(table, (list, tuple)) and isinstance(query, str): + # many tables, single query case + # duplicate query for every table + query = [query] * len(table) + + batch_outputs = self._batch_prepare_for_model( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def _batch_prepare_for_model( + self, + table: Union["pd.DataFrame", List["pd.DataFrame"]], + query: Optional[Union[TextInput, List[TextInput]]] = None, + answer: Optional[Union[str, List[str]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + This method adds special tokens, truncates sequences if overflowing while taking into account the special + tokens and manages a moving window (with user defined stride) for overflowing tokens. + """ + batch_outputs = {} + if answer is None: + answer = [None] * len(table) + for _table, _query, _answer in zip(table, query, answer): + text = self.prepare_table_query( + _table, _query, _answer, truncation_strategy=truncation_strategy, max_length=max_length + ) + + if self.do_lower_case: + text = text.lower() + + tokens = self.tokenize(text) + outputs = self.prepare_for_model( + ids=self.convert_tokens_to_ids(tokens), + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterwards + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterwards + return_attention_mask=False, # we pad in batch afterwards + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING) + def encode( + self, + table: "pd.DataFrame", + query: Optional[TextInput] = None, + answer: Optional[str] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, TapexTruncationStrategy] = None, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Prepare a table, a string and possible answer for the model. This method does not return token type IDs, + attention masks, etc. which are necessary for the model to work correctly. Use this method if you want to build + your processing on your own, otherwise refer to `__call__`. + """ + encoded_inputs = self.encode_plus( + table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, TAPEX_ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING) + def encode_plus( + self, + table: "pd.DataFrame", + query: Optional[TextInput] = None, + answer: Optional[str] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus( + table=table, + query=query, + answer=answer, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + table: "pd.DataFrame", + query: Optional[TextInput] = None, + answer: Optional[str] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + text = self.prepare_table_query( + table, query, answer, truncation_strategy=truncation_strategy, max_length=max_length + ) + + # if necessary, perform lower case + if self.do_lower_case: + text = text.lower() + + tokens = self.tokenize(text) + + return self.prepare_for_model( + ids=self.convert_tokens_to_ids(tokens), + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def target_call_func( + self, + answer: Union[str, List[str]], + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + The method tokenizes and prepares the answer label for the model. + + Args: + answer (`str` or `List[str]`): + Corresponding answer supervision to the queries for training the model. + """ + is_batched = isinstance(answer, (list, tuple)) + + if is_batched: + return self.target_batch_encode_plus( + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.target_encode_plus( + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def target_batch_encode_plus( + self, + answer: List[str], + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare answer strings for the model. + + Args: + answer `List[str]`: + Corresponding answer supervision to the queries for training the model. + """ + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._target_batch_encode_plus( + answer=answer, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _target_batch_encode_plus( + self, + answer: List[str], + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + batch_outputs = {} + for text in answer: + if self.do_lower_case: + text = text.lower() + + tokens = self.tokenize(text) + outputs = self.prepare_for_model( + ids=self.convert_tokens_to_ids(tokens), + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterwards + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterwards + return_attention_mask=False, # we pad in batch afterwards + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return BatchEncoding(batch_outputs) + + def target_encode( + self, + answer: str, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy, TapexTruncationStrategy] = None, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Prepare the answer string for the model. This method does not return token type IDs, attention masks, etc. + which are necessary for the model to work correctly. Use this method if you want to build your processing on + your own, otherwise refer to `__call__`. + + Args: + answer `str`: + Corresponding answer supervision to the queries for training the model + """ + encoded_outputs = self.target_encode_plus( + answer=answer, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_outputs["input_ids"] + + def target_encode_plus( + self, + answer: str, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str] = None, + max_length: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Prepare a answer string for the model. + + Args: + answer `str`: + Corresponding answer supervision to the queries for training the model. + """ + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._target_encode_plus( + answer=answer, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _target_encode_plus( + self, + answer: str, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + text = answer + + # if necessary, perform lower case + if self.do_lower_case: + text = text.lower() + + tokens = self.tokenize(text) + + return self.prepare_for_model( + ids=self.convert_tokens_to_ids(tokens), + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def prepare_table_query( + self, + table, + query, + answer=None, + truncation_strategy=Union[str, TruncationStrategy, TapexTruncationStrategy], + max_length=None, + ): + """ + This method can be used to linearize a table and add a corresponding query. + + Optionally, it also handles truncation of the table (cells). + + An answer can be provided for more precise truncation. + """ + if not table.empty: + # step 1: create table dictionary + table_content = {"header": list(table.columns), "rows": [list(row.values) for i, row in table.iterrows()]} + + # step 2: modify table internally + # always truncate table cells based on self.max_cell_length + # optionally truncate rows if truncation_strategy is set to it + self.truncate_table_cells(table_content, query, answer) + if truncation_strategy == TapexTruncationStrategy.DROP_ROWS_TO_FIT: + self.truncate_table_rows(table_content, query, answer, max_length=max_length) + + # step 3: linearize table + linear_table = self.table_linearize.process_table(table_content) + else: + linear_table = "" + + if linear_table == "": + logger.warning( + "You provide an empty table, or all cells contain much tokens (e.g., >= 1024 tokens). " + + f"Please carefully check the corresponding table with the query : {query}." + ) + if query == "": + logger.warning("You provide nothing to query with respect to the table.") + # step 4: concatenate query with linear_table + separator = " " if query and linear_table else "" + joint_input = (query + separator + linear_table) if query else linear_table + + return joint_input + + def truncate_table_cells(self, table_content: Dict, question: str, answer: List): + # TODO (Qian): is it possible to revert the original cell if it is in the final answer? + cell_mapping = {} + for row in table_content["rows"]: + for i, cell in enumerate(row): + truncate_cell = self.truncate_cell(cell) + if truncate_cell is not None: + cell_mapping[cell] = truncate_cell + row[i] = truncate_cell + + # modify the answer list + if answer is not None: + for i, case in enumerate(answer): + if case in cell_mapping.keys(): + answer[i] = cell_mapping[case] + + def truncate_cell(self, cell_value): + # do not process on these cases + if isinstance(cell_value, int) or isinstance(cell_value, float): + return cell_value + if cell_value.strip() != "": + try_tokens = self.tokenize(cell_value) + if len(try_tokens) >= self.max_cell_length: + retain_tokens = try_tokens[: self.max_cell_length] + retain_cell_value = self.convert_tokens_to_string(retain_tokens) + return retain_cell_value + else: + return None + else: + return cell_value + + def truncate_table_rows( + self, table_content: Dict, question: str, answer: Optional[Union[str, List[str]]] = None, max_length=None + ): + """ + Args: + table_content: + {"header": xxx, "rows": xxx, "id" (Optionally): xxx} + + question: + natural language sentence + + answer: + if for training, is the supervision; otherwise will be empty + """ + delete_ratio, remain_token_len = self.estimate_delete_ratio(table_content, question, max_length) + # randomly delete unrelated rows + self.delete_unrelated_rows(table_content, question, answer, delete_ratio) + # guarantee the result < max_length + maximum_keep_rows = 0 + for ind, row_example in enumerate(table_content["rows"]): + value_string = self.table_linearize.process_row(row_example, ind + 1) + value_token_len = len(self.tokenize(value_string)) + # over the size limit, and take action + if value_token_len > remain_token_len: + break + remain_token_len -= value_token_len + maximum_keep_rows += 1 + del table_content["rows"][maximum_keep_rows:] + + def estimate_delete_ratio(self, table_content: Dict, question: str, max_length=None): + if "header" not in table_content or "rows" not in table_content: + raise ValueError("The table content should contain both 'header' and 'rows' keys.") + # calculate the tokens of header, special tokens will only be pre-prepended into question + question_tokens = self.tokenize(question, add_special_tokens=True) + # calculate the tokens of header + header_string = self.table_linearize.process_header(table_content["header"]) + header_tokens = self.tokenize(header_string, add_special_tokens=False) + # split all cell values into tokens and see how many can be accommodated + used_token_len = len(question_tokens) + len(header_tokens) + # remaining token space for rows + remain_token_len = max_length - used_token_len + + value_string = "" + for _, row_example in enumerate(table_content["rows"]): + # use a general index to roughly estimate the overall token len + value_string += self.table_linearize.process_row(row_example, 100) + " " + value_token_len = len(self.tokenize(value_string)) + + if value_token_len < remain_token_len: + # no row will be deleted + return 0.0, remain_token_len + else: + # calc a roughly delete rate + return 1.0 - remain_token_len / value_token_len, remain_token_len + + def delete_unrelated_rows(self, table_content: Dict, question: str, answer: List, delete_ratio: float): + """ + The argument answer is used only during training. + """ + truncated_unrelated_indices = [] + related_indices = [] + if answer is None or len(answer) == 0: + answer_set = set() + else: + answer_set = {ans_ex.lower() for ans_ex in answer} + # add question key words into answer set + if question is not None: + answer_set.update(question.split()) + question_set = set(question.strip("?!.,").split(" ")) + row_max_len = len(table_content["rows"]) + for _row_idx, row in enumerate(table_content["rows"]): + lower_row = {str(cell).lower() for cell in row} + if len(lower_row & answer_set) == 0 and len(lower_row & question_set) == 0: + truncated_unrelated_indices.append(_row_idx) + else: + # add neighbours to preserve information aggressively + related_indices.extend([_row_idx - 2, _row_idx - 1, _row_idx, _row_idx + 1, _row_idx + 2]) + + # remove the neighbours + truncated_unrelated_indices = [ + _row_idx for _row_idx in truncated_unrelated_indices if _row_idx not in related_indices + ] + # select some cases to drop + drop_items = min(len(truncated_unrelated_indices), int(len(table_content["rows"]) * delete_ratio)) + drop_row_indices = random.choices(truncated_unrelated_indices, k=drop_items) + + for _row_idx in reversed(range(row_max_len)): + if _row_idx in drop_row_indices: + del table_content["rows"][_row_idx] + + # only when the drop ratio is too large, logging for warning. + if "id" in table_content and len(drop_row_indices) > 0: + logger.warning("Delete {:.2f} rows in table {}".format(len(drop_row_indices), table_content["id"])) + + +__all__ = ["TapexTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4bccdd12c9703313951e223d0ac1955f9ca1581 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_trajectory_transformer import * + from .modeling_trajectory_transformer import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a267cc4c4b82c6703eaa3a732d3fca1606a60d7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py @@ -0,0 +1,155 @@ +# coding=utf-8 +# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TrajectoryTransformer model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class TrajectoryTransformerConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TrajectoryTransformerModel`]. It is used to + instantiate an TrajectoryTransformer model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the + TrajectoryTransformer + [CarlCochet/trajectory-transformer-halfcheetah-medium-v2](https://huggingface.co/CarlCochet/trajectory-transformer-halfcheetah-medium-v2) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 100): + Vocabulary size of the TrajectoryTransformer model. Defines the number of different tokens that can be + represented by the `trajectories` passed when calling [`TrajectoryTransformerModel`] + action_weight (`int`, *optional*, defaults to 5): + Weight of the action in the loss function + reward_weight (`int`, *optional*, defaults to 1): + Weight of the reward in the loss function + value_weight (`int`, *optional*, defaults to 1): + Weight of the value in the loss function + block_size (`int`, *optional*, defaults to 249): + Size of the blocks in the trajectory transformer. + action_dim (`int`, *optional*, defaults to 6): + Dimension of the action space. + observation_dim (`int`, *optional*, defaults to 17): + Dimension of the observation space. + transition_dim (`int`, *optional*, defaults to 25): + Dimension of the transition space. + n_layer (`int`, *optional*, defaults to 4): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + n_embd (`int`, *optional*, defaults to 128): + Dimensionality of the embeddings and hidden states. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`int`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + kaiming_initializer_range (`float, *optional*, defaults to 1): + A coefficient scaling the negative slope of the kaiming initializer rectifier for EinLinear layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + Example: + + ```python + >>> from transformers import TrajectoryTransformerConfig, TrajectoryTransformerModel + + >>> # Initializing a TrajectoryTransformer CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration + >>> configuration = TrajectoryTransformerConfig() + + >>> # Initializing a model (with random weights) from the CarlCochet/trajectory-transformer-halfcheetah-medium-v2 style configuration + >>> model = TrajectoryTransformerModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "trajectory_transformer" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=100, + action_weight=5, + reward_weight=1, + value_weight=1, + block_size=249, + action_dim=6, + observation_dim=17, + transition_dim=25, + n_layer=4, + n_head=4, + n_embd=128, + embd_pdrop=0.1, + attn_pdrop=0.1, + resid_pdrop=0.1, + learning_rate=0.0006, + max_position_embeddings=512, + initializer_range=0.02, + layer_norm_eps=1e-12, + kaiming_initializer_range=1, + use_cache=True, + pad_token_id=1, + bos_token_id=50256, + eos_token_id=50256, + **kwargs, + ): + self.vocab_size = vocab_size + self.action_weight = action_weight + self.reward_weight = reward_weight + self.value_weight = value_weight + self.max_position_embeddings = max_position_embeddings + self.block_size = block_size + self.action_dim = action_dim + self.observation_dim = observation_dim + self.transition_dim = transition_dim + self.learning_rate = learning_rate + self.n_layer = n_layer + self.n_head = n_head + self.n_embd = n_embd + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.resid_pdrop = resid_pdrop + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.kaiming_initializer_range = kaiming_initializer_range + self.use_cache = use_cache + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +__all__ = ["TrajectoryTransformerConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..da7f7806671dbace1a10bd60d93e6782e27a5136 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TrajectoryTransformer pytorch checkpoint conversion""" + +import torch +import trajectory.utils as utils + +from transformers import TrajectoryTransformerModel + + +class Parser(utils.Parser): + dataset: str = "halfcheetah-medium-expert-v2" + config: str = "config.offline" + + +def convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch(logbase, dataset, loadpath, epoch, device): + """Converting Sequential blocks to ModuleList""" + + gpt, gpt_epoch = utils.load_model(logbase, dataset, loadpath, epoch=epoch, device=device) + trajectory_transformer = TrajectoryTransformerModel(gpt.config) + + trajectory_transformer.tok_emb.load_state_dict(gpt.tok_emb.state_dict()) + trajectory_transformer.pos_emb = gpt.pos_emb + trajectory_transformer.drop.load_state_dict(gpt.drop.state_dict()) + trajectory_transformer.ln_f.load_state_dict(gpt.ln_f.state_dict()) + trajectory_transformer.head.load_state_dict(gpt.head.state_dict()) + + for i, block in enumerate(gpt.blocks): + trajectory_transformer.blocks[i].ln1.load_state_dict(gpt.blocks[i].ln1.state_dict()) + trajectory_transformer.blocks[i].ln2.load_state_dict(gpt.blocks[i].ln2.state_dict()) + trajectory_transformer.blocks[i].attn.load_state_dict(gpt.blocks[i].attn.state_dict()) + + trajectory_transformer.blocks[i].l1.load_state_dict(gpt.blocks[i].mlp[0].state_dict()) + trajectory_transformer.blocks[i].act.load_state_dict(gpt.blocks[i].mlp[1].state_dict()) + trajectory_transformer.blocks[i].l2.load_state_dict(gpt.blocks[i].mlp[2].state_dict()) + trajectory_transformer.blocks[i].drop.load_state_dict(gpt.blocks[i].mlp[3].state_dict()) + + torch.save(trajectory_transformer.state_dict(), "pytorch_model.bin") + + +if __name__ == "__main__": + """ + To run this script you will need to install the original repository to run the original model. You can find it + here: https://github.com/jannerm/trajectory-transformer From this repository code you can also download the + original pytorch checkpoints. + + Run with the command: + + ```sh + >>> python convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py --dataset + ... --gpt_loadpath + ``` + """ + + args = Parser().parse_args("plan") + convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch( + args.logbase, args.dataset, args.gpt_loadpath, args.gpt_epoch, args.device + ) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a9a0111d22218af6af00d7561107d0c70def2db --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -0,0 +1,610 @@ +# coding=utf-8 +# Copyright 2022 The Trajectory Transformers paper authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TrajectoryTransformer model.""" + +import math +import os +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F + +from ....modeling_utils import PreTrainedModel +from ....utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_trajectory_transformer import TrajectoryTransformerConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "CarlCochet/trajectory-transformer-halfcheetah-medium-v2" +_CONFIG_FOR_DOC = "TrajectoryTransformerConfig" + + +def load_tf_weights_in_trajectory_transformer(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +@dataclass +class TrajectoryTransformerOutput(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the + attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. GPT2Attentions weights after the attention softmax, used to compute the weighted average + in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class TrajectoryTransformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TrajectoryTransformerConfig + load_tf_weights = load_tf_weights_in_trajectory_transformer + base_model_prefix = "trajectory_transformer" + main_input_name = "trajectories" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, EinLinear): + for i in range(module.n_models): + nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight[i]) + bound = (1 / math.sqrt(fan_in)) * self.config.initializer_range + nn.init.uniform_(module.bias[i], -bound, bound) + + +TRAJECTORY_TRANSFORMER_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TrajectoryTransformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING = r""" + Args: + trajectories (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Batch of trajectories, where a trajectory is a sequence of states, actions and rewards. + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`, *optional*): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + targets (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Desired targets used to compute the loss. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class EinLinear(nn.Module): + def __init__(self, n_models, in_features, out_features, bias): + super().__init__() + self.n_models = n_models + self.out_features = out_features + self.in_features = in_features + self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.Tensor(n_models, out_features)) + else: + self.register_parameter("bias", None) + + def reset_parameters(self): + for i in range(self.n_models): + nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i]) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias[i], -bound, bound) + + def forward(self, input): + """ + Args: + input (`torch.FloatTensor` of shape `(B, n_models, input_dim)`): + The input to the layer. + """ + # [ batch_size x n_models x output_dim ] + output = torch.einsum("eoi,bei->beo", self.weight, input) + if self.bias is not None: + raise RuntimeError() + return output + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + + if config.n_embd % config.n_head != 0: + raise ValueError(f"n_head ({config.n_head}) should be a divisor of n_embd ({config.n_embd})") + + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "mask", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + persistent=False, + ) + + # mask previous value estimates + joined_dim = config.observation_dim + config.action_dim + 2 + self.mask.squeeze()[:, joined_dim - 1 :: joined_dim] = 0 + + self.n_head = config.n_head + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + batch_size, sequence_length, embedding_dim = hidden_states.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + # [ batch_size x n_heads x sequence_length x head_dim ] + key = ( + self.key(hidden_states) + .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head) + .transpose(1, 2) + ) + query = ( + self.query(hidden_states) + .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head) + .transpose(1, 2) + ) + value = ( + self.value(hidden_states) + .view(batch_size, sequence_length, self.n_head, embedding_dim // self.n_head) + .transpose(1, 2) + ) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # causal self-attention + # [ batch_size x n_heads x sequence_length x sequence_length ] + attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1))) + attn_weights = attn_weights.masked_fill( + self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min + ) + attn_weights = F.softmax(attn_weights, dim=-1) + self._attn_map = attn_weights.clone() + attn_weights = self.attn_drop(attn_weights) + + output = torch.matmul(attn_weights, value) + # [ batch_size x sequence_length x embedding_dim ] + # re-assemble all head outputs side by side + output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, embedding_dim) + + # output projection + output = self.resid_drop(self.proj(output)) + + outputs = (output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + + # MLP + self.l1 = nn.Linear(config.n_embd, 4 * config.n_embd) + self.act = nn.GELU() + self.l2 = nn.Linear(4 * config.n_embd, config.n_embd) + self.drop = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + residual = hidden_states + hidden_states = self.ln1(hidden_states) + + attn_outputs = self.attn( + hidden_states, layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions + ) + attn_output = attn_outputs[0] + outputs = attn_outputs[1:] + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln2(hidden_states) + hidden_states = self.l1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.l2(hidden_states) + hidden_states = residual + self.drop(hidden_states) + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs + + +@add_start_docstrings( + "The bare TrajectoryTransformer Model transformer outputting raw hidden-states without any specific head on top.", + TRAJECTORY_TRANSFORMER_START_DOCSTRING, +) +class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): + """the full GPT language model, with a context size of block_size""" + + def __init__(self, config): + super().__init__(config) + + # input embedding stem (+1 for stop token) + self.tok_emb = nn.Embedding(config.vocab_size * config.transition_dim + 1, config.n_embd) + + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = EinLinear(config.transition_dim, config.n_embd, config.vocab_size + 1, bias=False) + + self.vocab_size = config.vocab_size + self.stop_token = config.vocab_size * config.transition_dim + self.block_size = config.block_size + + self.observation_dim = config.observation_dim + self.action_dim = config.action_dim + self.transition_dim = config.transition_dim + self.embedding_dim = config.n_embd + + self.action_weight = config.action_weight + self.reward_weight = config.reward_weight + self.value_weight = config.value_weight + + self.gradient_checkpointing = False + + self.post_init() + + def get_block_size(self): + return self.block_size + + def offset_tokens(self, trajectories): + _, sequence_length = trajectories.shape + + n_states = int(np.ceil(sequence_length / self.transition_dim)) + + offsets = torch.arange(self.transition_dim) * self.vocab_size + offsets = offsets.repeat(n_states).to(trajectories.device) + + offset_trajectories = trajectories + offsets[:sequence_length] + offset_trajectories[trajectories == self.vocab_size] = self.stop_token + return offset_trajectories + + def pad_to_full_observation(self, hidden_states): + batch_size, sequence_length, _ = hidden_states.shape + + n_pad = (self.transition_dim - sequence_length % self.transition_dim) % self.transition_dim + padding = torch.zeros(batch_size, n_pad, self.embedding_dim, device=hidden_states.device) + + # [ batch_size x padded_sequence_length' x embedding_dim ] + hidden_states_pad = torch.cat([hidden_states, padding], dim=1) + hidden_states_pad = hidden_states_pad.view(-1, self.transition_dim, self.embedding_dim) + + return hidden_states_pad, n_pad + + @add_start_docstrings_to_model_forward( + TRAJECTORY_TRANSFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length") + ) + @replace_return_docstrings(output_type=TrajectoryTransformerOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + trajectories: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + targets: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TrajectoryTransformerOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import TrajectoryTransformerModel + >>> import torch + + >>> model = TrajectoryTransformerModel.from_pretrained( + ... "CarlCochet/trajectory-transformer-halfcheetah-medium-v2" + ... ) + >>> model.to(device) + >>> model.eval() + + >>> observations_dim, action_dim, batch_size = 17, 6, 256 + >>> seq_length = observations_dim + action_dim + 1 + + >>> trajectories = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to( + ... device + ... ) + >>> targets = torch.LongTensor([np.random.permutation(self.seq_length) for _ in range(batch_size)]).to(device) + + >>> outputs = model( + ... trajectories, + ... targets=targets, + ... use_cache=True, + ... output_attentions=True, + ... output_hidden_states=True, + ... return_dict=True, + ... ) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if past_key_values is None: + past_key_values = tuple([None] * len(self.blocks)) + + batch_size, sequence_length = trajectories.size() + + if sequence_length > self.block_size: + raise ValueError("Cannot forward, model block size is exhausted.") + + offset_trajectories = self.offset_tokens(trajectories) + # [ batch_size x sequence_length x embedding_dim ] + # forward the GPT model + token_embeddings = self.tok_emb(offset_trajectories) # each index maps to a (learnable) vector + position_embeddings = self.pos_emb[:, :sequence_length, :] # each position maps to a (learnable) vector + + hidden_states = self.drop(token_embeddings + position_embeddings) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, (block, layer_past) in enumerate(zip(self.blocks, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + layer_past, + use_cache, + output_attentions, + ) + else: + outputs = block(hidden_states, layer_past, use_cache, output_attentions) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # [ batch_size x sequence_length x embedding_dim ] + hidden_state = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states_pad, n_pad = self.pad_to_full_observation(hidden_state) + + logits = self.head(hidden_states_pad) + logits = logits.reshape(batch_size, sequence_length + n_pad, self.vocab_size + 1) + logits = logits[:, :sequence_length] + + # if we are given some desired targets also calculate the loss + if targets is not None: + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.view(-1), reduction="none") + if self.action_weight != 1 or self.reward_weight != 1 or self.value_weight != 1: + # make weights + n_states = int(np.ceil(sequence_length / self.transition_dim)) + weights = torch.cat( + [ + torch.ones(self.observation_dim, device=trajectories.device), + torch.ones(self.action_dim, device=trajectories.device) * self.action_weight, + torch.ones(1, device=trajectories.device) * self.reward_weight, + torch.ones(1, device=trajectories.device) * self.value_weight, + ] + ) + weights = weights.repeat(n_states) + weights = weights[1:].repeat(batch_size, 1) + loss = loss * weights.view(-1) + loss = (loss * attention_mask.view(-1)).mean() + else: + loss = None + + if not return_dict: + return tuple(v for v in [loss, logits, presents, all_hidden_states, all_self_attentions] if v is not None) + + return TrajectoryTransformerOutput( + loss=loss, + logits=logits, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +__all__ = [ + "TrajectoryTransformerModel", + "TrajectoryTransformerPreTrainedModel", + "load_tf_weights_in_trajectory_transformer", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac9a2cbf4766bd187d90fdaf46505db3e92b68f --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_transfo_xl import * + from .modeling_tf_transfo_xl import * + from .modeling_transfo_xl import * + from .tokenization_transfo_xl import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..23972deae2ca9b31ce5adcf2aaf4f6e6cd4b7587 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/configuration_transfo_xl.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Transformer XL configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class TransfoXLConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`TransfoXLModel`] or a [`TFTransfoXLModel`]. It is + used to instantiate a Transformer-XL model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the TransfoXL + [transfo-xl/transfo-xl-wt103](https://huggingface.co/transfo-xl/transfo-xl-wt103) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 267735): + Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`TransfoXLModel`] or [`TFTransfoXLModel`]. + cutoffs (`List[int]`, *optional*, defaults to `[20000, 40000, 200000]`): + Cutoffs for the adaptive softmax. + d_model (`int`, *optional*, defaults to 1024): + Dimensionality of the model's hidden states. + d_embed (`int`, *optional*, defaults to 1024): + Dimensionality of the embeddings + n_head (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + d_head (`int`, *optional*, defaults to 64): + Dimensionality of the model's heads. + d_inner (`int`, *optional*, defaults to 4096): + Inner dimension in FF + div_val (`int`, *optional*, defaults to 4): + Divident value for adapative input and softmax + pre_lnorm (`boolean`, *optional*, defaults to `False`): + Whether or not to apply LayerNorm to the input instead of the output in the blocks. + n_layer (`int`, *optional*, defaults to 18): + Number of hidden layers in the Transformer encoder. + mem_len (`int`, *optional*, defaults to 1600): + Length of the retained previous heads. + clamp_len (`int`, *optional*, defaults to 1000): + Use the same pos embeddings after clamp_len. + same_length (`boolean`, *optional*, defaults to `True`): + Whether or not to use the same attn length for all tokens + proj_share_all_but_first (`boolean`, *optional*, defaults to `True`): + True to share all but first projs, False not to share. + attn_type (`int`, *optional*, defaults to 0): + Attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. + sample_softmax (`int`, *optional*, defaults to -1): + Number of samples in the sampled softmax. + adaptive (`boolean`, *optional*, defaults to `True`): + Whether or not to use adaptive softmax. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + dropatt (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + untie_r (`boolean`, *optional*, defaults to `True`): + Whether ot not to untie relative position biases. + init (`str`, *optional*, defaults to `"normal"`): + Parameter initializer to use. + init_range (`float`, *optional*, defaults to 0.01): + Parameters initialized by U(-init_range, init_range). + proj_init_std (`float`, *optional*, defaults to 0.01): + Parameters initialized by N(0, init_std) + init_std (`float`, *optional*, defaults to 0.02): + Parameters initialized by N(0, init_std) + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers + eos_token_id (`int`, *optional*, defaults to 0): + End of stream token id. + + Examples: + + ```python + >>> from transformers import TransfoXLConfig, TransfoXLModel + + >>> # Initializing a Transformer XL configuration + >>> configuration = TransfoXLConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = TransfoXLModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "transfo-xl" + keys_to_ignore_at_inference = ["mems"] + attribute_map = { + "n_token": "vocab_size", + "hidden_size": "d_model", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=267735, + cutoffs=[20000, 40000, 200000], + d_model=1024, + d_embed=1024, + n_head=16, + d_head=64, + d_inner=4096, + div_val=4, + pre_lnorm=False, + n_layer=18, + mem_len=1600, + clamp_len=1000, + same_length=True, + proj_share_all_but_first=True, + attn_type=0, + sample_softmax=-1, + adaptive=True, + dropout=0.1, + dropatt=0.0, + untie_r=True, + init="normal", + init_range=0.01, + proj_init_std=0.01, + init_std=0.02, + layer_norm_epsilon=1e-5, + eos_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.cutoffs = [] + self.cutoffs.extend(cutoffs) + if proj_share_all_but_first: + self.tie_projs = [False] + [True] * len(self.cutoffs) + else: + self.tie_projs = [False] + [False] * len(self.cutoffs) + self.d_model = d_model + self.d_embed = d_embed + self.d_head = d_head + self.d_inner = d_inner + self.div_val = div_val + self.pre_lnorm = pre_lnorm + self.n_layer = n_layer + self.n_head = n_head + self.mem_len = mem_len + self.same_length = same_length + self.attn_type = attn_type + self.clamp_len = clamp_len + self.sample_softmax = sample_softmax + self.adaptive = adaptive + self.dropout = dropout + self.dropatt = dropatt + self.untie_r = untie_r + self.init = init + self.init_range = init_range + self.proj_init_std = proj_init_std + self.init_std = init_std + self.layer_norm_epsilon = layer_norm_epsilon + super().__init__(eos_token_id=eos_token_id, **kwargs) + + @property + def max_position_embeddings(self): + # Message copied from Transformer-XL documentation + logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.") + return -1 + + @max_position_embeddings.setter + def max_position_embeddings(self, value): + # Message copied from Transformer-XL documentation + raise NotImplementedError( + f"The model {self.model_type} is one of the few models that has no sequence length limit." + ) + + +__all__ = ["TransfoXLConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..2c7b687c4d98974ebf99790a201f7c3221a5498a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Transformer XL checkpoint and datasets.""" + +import argparse +import os +import pickle +import sys + +import torch + +from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl +from transformers.models.deprecated.transfo_xl import tokenization_transfo_xl as data_utils +from transformers.models.deprecated.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, logging + + +logging.set_verbosity_info() + +# We do this to be able to load python 2 datasets pickles +# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 +data_utils.Vocab = data_utils.TransfoXLTokenizer +data_utils.Corpus = data_utils.TransfoXLCorpus +sys.modules["data_utils"] = data_utils +sys.modules["vocabulary"] = data_utils + + +def convert_transfo_xl_checkpoint_to_pytorch( + tf_checkpoint_path, transfo_xl_config_file, pytorch_dump_folder_path, transfo_xl_dataset_file +): + if transfo_xl_dataset_file: + # Convert a pre-processed corpus (see original TensorFlow repo) + with open(transfo_xl_dataset_file, "rb") as fp: + corpus = pickle.load(fp, encoding="latin1") + # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) + pytorch_vocab_dump_path = pytorch_dump_folder_path + "/" + VOCAB_FILES_NAMES["pretrained_vocab_file"] + print(f"Save vocabulary to {pytorch_vocab_dump_path}") + corpus_vocab_dict = corpus.vocab.__dict__ + torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) + + corpus_dict_no_vocab = corpus.__dict__ + corpus_dict_no_vocab.pop("vocab", None) + pytorch_dataset_dump_path = pytorch_dump_folder_path + "/" + CORPUS_NAME + print(f"Save dataset to {pytorch_dataset_dump_path}") + torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) + + if tf_checkpoint_path: + # Convert a pre-trained TensorFlow model + config_path = os.path.abspath(transfo_xl_config_file) + tf_path = os.path.abspath(tf_checkpoint_path) + + print(f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}.") + # Initialise PyTorch model + if transfo_xl_config_file == "": + config = TransfoXLConfig() + else: + config = TransfoXLConfig.from_json_file(transfo_xl_config_file) + print(f"Building PyTorch model from configuration: {config}") + model = TransfoXLLMHeadModel(config) + + model = load_tf_weights_in_transfo_xl(model, config, tf_path) + # Save pytorch-model + pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) + pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) + print(f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}") + torch.save(model.state_dict(), pytorch_weights_dump_path) + print(f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}") + with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: + f.write(config.to_json_string()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=True, + help="Path to the folder to store the PyTorch model or dataset/vocab.", + ) + parser.add_argument( + "--tf_checkpoint_path", + default="", + type=str, + help="An optional path to a TensorFlow checkpoint path to be converted.", + ) + parser.add_argument( + "--transfo_xl_config_file", + default="", + type=str, + help=( + "An optional config json file corresponding to the pre-trained BERT model. \n" + "This specifies the model architecture." + ), + ) + parser.add_argument( + "--transfo_xl_dataset_file", + default="", + type=str, + help="An optional dataset file to be converted in a vocabulary.\n" + "Given the files are in the pickle format, please be wary of passing it files you trust.", + ) + args = parser.parse_args() + convert_transfo_xl_checkpoint_to_pytorch( + args.tf_checkpoint_path, + args.transfo_xl_config_file, + args.pytorch_dump_folder_path, + args.transfo_xl_dataset_file, + ) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..90e2ebc3436db35a002120601ff60b66169318bb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl.py @@ -0,0 +1,1129 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TF 2.0 Transformer XL model. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ....modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ....tf_utils import shape_list, stable_softmax +from ....utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_transfo_xl import TransfoXLConfig +from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "transfo-xl/transfo-xl-wt103" +_CONFIG_FOR_DOC = "TransfoXLConfig" + + +class TFPositionalEmbedding(keras.layers.Layer): + def __init__(self, demb, **kwargs): + super().__init__(**kwargs) + + self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb)) + + def call(self, pos_seq, bsz=None): + self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype) + sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq) + pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) + + if bsz is not None: + return tf.tile(pos_emb[:, None, :], [1, bsz, 1]) + else: + return pos_emb[:, None, :] + + +class TFPositionwiseFF(keras.layers.Layer): + def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5, init_std=0.02, **kwargs): + super().__init__(**kwargs) + + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + + self.layer_1 = keras.layers.Dense( + d_inner, kernel_initializer=get_initializer(init_std), activation=tf.nn.relu, name="CoreNet_._0" + ) + self.drop_1 = keras.layers.Dropout(dropout) + self.layer_2 = keras.layers.Dense(d_model, kernel_initializer=get_initializer(init_std), name="CoreNet_._3") + self.drop_2 = keras.layers.Dropout(dropout) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm") + + self.pre_lnorm = pre_lnorm + + def call(self, inp, training=False): + if self.pre_lnorm: + # layer normalization + positionwise feed-forward + core_out = self.layer_norm(inp) + core_out = self.layer_1(core_out) + core_out = self.drop_1(core_out, training=training) + core_out = self.layer_2(core_out) + core_out = self.drop_2(core_out, training=training) + + # residual connection + output = core_out + inp + else: + # positionwise feed-forward + core_out = self.layer_1(inp) + core_out = self.drop_1(core_out, training=training) + core_out = self.layer_2(core_out) + core_out = self.drop_2(core_out, training=training) + + # residual connection + layer normalization + output = self.layer_norm(inp + core_out) + + return output + + +class TFRelPartialLearnableMultiHeadAttn(keras.layers.Layer): + def __init__( + self, + n_head, + d_model, + d_head, + dropout, + dropatt=0.0, + pre_lnorm=False, + r_r_bias=None, + r_w_bias=None, + layer_norm_epsilon=1e-5, + init_std=0.02, + output_attentions=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + self.output_attentions = output_attentions + + self.qkv_net = keras.layers.Dense( + 3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net" + ) + + self.drop = keras.layers.Dropout(dropout) + self.dropatt = keras.layers.Dropout(dropatt) + self.o_net = keras.layers.Dense( + d_model, kernel_initializer=get_initializer(init_std), use_bias=False, name="o_net" + ) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layer_norm") + + self.scale = 1 / (d_head**0.5) + + self.pre_lnorm = pre_lnorm + + if r_r_bias is not None and r_w_bias is not None: # Biases are shared + self.r_r_bias = r_r_bias + self.r_w_bias = r_w_bias + else: + self.r_r_bias = None + self.r_w_bias = None + + self.r_net = keras.layers.Dense( + self.n_head * self.d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="r_net" + ) + + def build(self, input_shape): + if self.r_r_bias is None or self.r_w_bias is None: # Biases are not shared + self.r_r_bias = self.add_weight( + shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias" + ) + self.r_w_bias = self.add_weight( + shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias" + ) + super().build(input_shape) + + def _rel_shift(self, x): + x_size = shape_list(x) + + x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) + x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]]) + x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) + x = tf.reshape(x, x_size) + + return x + + def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False): + qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1] + + if mems is not None: + mems = tf.cast(mems, dtype=w.dtype) + cat = tf.concat([mems, w], 0) + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(cat)) + else: + w_heads = self.qkv_net(cat) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) + w_head_q = w_head_q[-qlen:] + else: + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(w)) + else: + w_heads = self.qkv_net(w) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1) + + klen = shape_list(w_head_k)[0] + + w_head_q = tf.reshape(w_head_q, (qlen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head + w_head_k = tf.reshape(w_head_k, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head + w_head_v = tf.reshape(w_head_v, (klen, bsz, self.n_head, self.d_head)) # qlen x bsz x n_head x d_head + + r_head_k = tf.reshape(r_head_k, (rlen, self.n_head, self.d_head)) # qlen x n_head x d_head + + # compute attention score + rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head + AC = tf.einsum("ibnd,jbnd->ijbn", rw_head_q, w_head_k) # qlen x klen x bsz x n_head + + rr_head_q = w_head_q + self.r_r_bias + BD = tf.einsum("ibnd,jnd->ijbn", rr_head_q, r_head_k) # qlen x klen x bsz x n_head + BD = self._rel_shift(BD) + + # [qlen x klen x bsz x n_head] + attn_score = AC + BD + attn_score = attn_score * self.scale + + # compute attention probability + if attn_mask is not None: + attn_mask_t = attn_mask[:, :, None, None] + attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype) + attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t + + # [qlen x klen x bsz x n_head] + attn_prob = stable_softmax(attn_score, axis=1) + attn_prob = self.dropatt(attn_prob, training=training) + + # Mask heads if we want to + if head_mask is not None: + attn_prob = attn_prob * head_mask + + # compute attention vector + attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, w_head_v) + + # [qlen x bsz x n_head x d_head] + attn_vec_sizes = shape_list(attn_vec) + attn_vec = tf.reshape(attn_vec, (attn_vec_sizes[0], attn_vec_sizes[1], self.n_head * self.d_head)) + + # linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out, training=training) + + if self.pre_lnorm: + # residual connection + outputs = [w + attn_out] + else: + # residual connection + layer normalization + outputs = [self.layer_norm(w + attn_out)] + + if output_attentions: + outputs.append(attn_prob) + + return outputs + + +class TFRelPartialLearnableDecoderLayer(keras.layers.Layer): + def __init__( + self, + n_head, + d_model, + d_head, + d_inner, + dropout, + dropatt=0.0, + pre_lnorm=False, + r_w_bias=None, + r_r_bias=None, + layer_norm_epsilon=1e-5, + init_std=0.02, + output_attentions=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.dec_attn = TFRelPartialLearnableMultiHeadAttn( + n_head, + d_model, + d_head, + dropout, + dropatt=dropatt, + pre_lnorm=pre_lnorm, + r_w_bias=r_w_bias, + r_r_bias=r_r_bias, + init_std=init_std, + layer_norm_epsilon=layer_norm_epsilon, + output_attentions=output_attentions, + name="dec_attn", + ) + self.pos_ff = TFPositionwiseFF( + d_model, + d_inner, + dropout, + pre_lnorm=pre_lnorm, + init_std=init_std, + layer_norm_epsilon=layer_norm_epsilon, + name="pos_ff", + ) + + def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False): + attn_outputs = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training) + ff_output = self.pos_ff(attn_outputs[0], training=training) + + outputs = [ff_output] + attn_outputs[1:] + + return outputs + + +class TFTransfoEmbeddings(keras.layers.Layer): + def __init__(self, vocab_size, emb_size, init_std, **kwargs): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.emb_size = emb_size + self.init_std = init_std + + def build(self, input_shape): + self.weight = self.add_weight( + shape=(self.vocab_size, self.emb_size), + initializer=get_initializer(self.init_std), + name="embeddings", + ) + + super().build(input_shape) + + def call(self, inputs): + return tf.gather(self.weight, inputs) + + +class TFAdaptiveEmbedding(keras.layers.Layer): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs): + super().__init__(**kwargs) + + self.n_token = n_token + self.d_embed = d_embed + self.init_std = init_std + + self.cutoffs = cutoffs + [n_token] + self.div_val = div_val + self.d_proj = d_proj + + self.emb_scale = d_proj**0.5 + + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = [] + self.emb_projs = [] + + if div_val == 1: + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val**i) + self.emb_layers.append( + TFTransfoEmbeddings( + r_idx - l_idx, + d_emb_i, + init_std, + name=f"emb_layers_._{i}", + ) + ) + + def build(self, input_shape): + for i in range(len(self.cutoffs)): + d_emb_i = self.d_embed // (self.div_val**i) + self.emb_projs.append( + self.add_weight( + shape=(d_emb_i, self.d_proj), + initializer=get_initializer(self.init_std), + trainable=True, + name=f"emb_projs_._{i}", + ) + ) + + super().build(input_shape) + + def call(self, inp): + if self.div_val == 1: + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + else: + inp_flat = tf.reshape(inp, (-1,)) + emb_flat = tf.zeros([shape_list(inp_flat)[0], self.d_proj]) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + + inp_i = tf.boolean_mask(inp_flat, mask_i) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i]) + + mask_idx = tf.where(mask_i) + scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat)) + emb_flat = tf.cast(emb_flat, dtype=scatter.dtype) + emb_flat += scatter + + embed_shape = shape_list(inp) + [self.d_proj] + embed = tf.reshape(emb_flat, embed_shape) + + embed *= self.emb_scale + + return embed + + +@keras_serializable +class TFTransfoXLMainLayer(keras.layers.Layer): + config_class = TransfoXLConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + self.return_dict = config.use_return_dict + + self.n_token = config.vocab_size + + self.d_embed = config.d_embed + self.d_model = config.d_model + self.n_head = config.n_head + self.d_head = config.d_head + self.untie_r = config.untie_r + + self.word_emb = TFAdaptiveEmbedding( + config.vocab_size, + config.d_embed, + config.d_model, + config.cutoffs, + div_val=config.div_val, + init_std=config.init_std, + name="word_emb", + ) + + self.drop = keras.layers.Dropout(config.dropout) + + self.n_layer = config.n_layer + self.mem_len = config.mem_len + self.attn_type = config.attn_type + + self.layers = [] + if config.attn_type == 0: # the default attention + for i in range(config.n_layer): + self.layers.append( + TFRelPartialLearnableDecoderLayer( + config.n_head, + config.d_model, + config.d_head, + config.d_inner, + config.dropout, + dropatt=config.dropatt, + pre_lnorm=config.pre_lnorm, + r_w_bias=None if self.untie_r else self.r_w_bias, + r_r_bias=None if self.untie_r else self.r_r_bias, + layer_norm_epsilon=config.layer_norm_epsilon, + init_std=config.init_std, + output_attentions=self.output_attentions, + name=f"layers_._{i}", + ) + ) + else: # learnable embeddings and absolute embeddings + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + + self.same_length = config.same_length + self.clamp_len = config.clamp_len + + if self.attn_type == 0: # default attention + self.pos_emb = TFPositionalEmbedding(self.d_model, name="pos_emb") + else: # learnable embeddings and absolute embeddings + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + + def build(self, input_shape): + if not self.untie_r: + self.r_w_bias = self.add_weight( + shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias" + ) + self.r_r_bias = self.add_weight( + shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias" + ) + super().build(input_shape) + + def get_input_embeddings(self): + return self.word_emb + + def set_input_embeddings(self, value): + raise NotImplementedError + + def backward_compatible(self): + self.sample_softmax = -1 + + def reset_memory_length(self, mem_len): + self.mem_len = mem_len + + def _prune_heads(self, heads): + raise NotImplementedError + + def init_mems(self, bsz): + if self.mem_len > 0: + mems = [] + for i in range(self.n_layer): + empty = tf.zeros([self.mem_len, bsz, self.d_model]) + mems.append(empty) + + return mems + else: + return None + + def _update_mems(self, hids, mems, mlen, qlen): + # does not deal with None + if mems is None: + return None + + # mems is not None + assert len(hids) == len(mems), "len(hids) != len(mems)" + + # There are `mlen + qlen` steps that can be cached into mems + new_mems = [] + end_idx = mlen + tf.math.maximum(0, qlen) + beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len)) + for i in range(len(hids)): + mems[i] = tf.cast(mems[i], dtype=hids[i].dtype) + cat = tf.concat([mems[i], hids[i]], axis=0) + tf.stop_gradient(cat) + new_mems.append(cat[beg_idx:end_idx]) + + return new_mems + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + mems: List[tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ): + # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library + # so we transpose here from shape [bsz, len] to shape [len, bsz] + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_ids = tf.transpose(input_ids, perm=(1, 0)) + qlen, bsz = shape_list(input_ids) + elif inputs_embeds is not None: + inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2)) + qlen, bsz = shape_list(inputs_embeds)[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if mems is None: + mems = self.init_mems(bsz) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) + # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.n_layer + + if inputs_embeds is not None: + word_emb = inputs_embeds + else: + word_emb = self.word_emb(input_ids) + + mlen = shape_list(mems[0])[0] if mems is not None else 0 + klen = mlen + qlen + + # Compute decoder attention mask + all_ones = tf.ones([qlen, klen], dtype=tf.int32) + upper_mask = 1 - tf.linalg.band_part(tf.ones([qlen, klen], dtype=tf.int32), -1, mlen) + if self.same_length: + mask_len = klen - self.mem_len + mask_shift_len = qlen - tf.nn.relu(mask_len) # Lazy clamping of negatives to zero + + # Use an indicator variable instead of a conditional to keep the compiler happy + lower_mask = tf.linalg.band_part(all_ones, -1, 0) - ( + tf.linalg.band_part(all_ones, mask_shift_len - 1, 0) * tf.cast(mask_shift_len != 0, tf.int32) + ) + dec_attn_mask = upper_mask + lower_mask + else: + dec_attn_mask = upper_mask + + hids = [] + attentions = [] if output_attentions else None + if self.attn_type == 0: # default + pos_seq = tf.range(klen - 1, -1, -1.0) + if self.clamp_len > 0: + pos_seq = tf.minimum(pos_seq, self.clamp_len) + pos_emb = self.pos_emb(pos_seq) + + core_out = self.drop(word_emb, training=training) + pos_emb = self.drop(pos_emb, training=training) + + for i, layer in enumerate(self.layers): + hids.append(core_out) + mems_i = None if mems is None else mems[i] + layer_outputs = layer( + core_out, + pos_emb, + dec_attn_mask, + mems_i, + head_mask[i], + output_attentions, + training=training, + ) + core_out = layer_outputs[0] + if output_attentions: + attentions.append(layer_outputs[1]) + else: # learnable embeddings and absolute embeddings + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + + core_out = self.drop(core_out, training=training) + + new_mems = self._update_mems(hids, mems, mlen, qlen) + + # We transpose back here to shape [bsz, len, hidden_dim] + core_out = tf.transpose(core_out, perm=(1, 0, 2)) + + if output_hidden_states: + # Transpose to library standard shape [bsz, len, hidden_dim] and add last layer + hids = tuple(tf.transpose(t, perm=(1, 0, 2)) for t in hids) + hids = hids + (core_out,) + else: + hids = None + if output_attentions: + # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] + attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions) + + if not return_dict: + return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None) + + return TFTransfoXLModelOutput( + last_hidden_state=core_out, + mems=new_mems, + hidden_states=hids, + attentions=attentions, + ) + + +class TFTransfoXLPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TransfoXLConfig + base_model_prefix = "transformer" + + +@dataclass +class TFTransfoXLModelOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + mems (`List[tf.Tensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[tf.Tensor] = None + mems: List[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFTransfoXLLMHeadModelOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + losses (`tf.Tensor` of shape *(batch_size, sequence_length-1)*, *optional*, returned when `labels` is provided): + Language modeling losses (not reduced). + prediction_scores (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax). + mems (`List[tf.Tensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + prediction_scores: Optional[tf.Tensor] = None + mems: List[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFTransfoXLSequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + mems (`List[tf.Tensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: tf.Tensor | None = None + logits: Optional[tf.Tensor] = None + mems: List[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +TRANSFO_XL_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TRANSFO_XL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + mems (`List[tf.Tensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see + `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems + given to this model should not be passed as `input_ids` as they have already been computed. + head_mask (`tf.Tensor` or `Numpy array` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + TRANSFO_XL_START_DOCSTRING, +) +class TFTransfoXLModel(TFTransfoXLPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.transformer = TFTransfoXLMainLayer(config, name="transformer") + + @unpack_inputs + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTransfoXLModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + mems: List[tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFTransfoXLModelOutput | Tuple[tf.Tensor]: + outputs = self.transformer( + input_ids=input_ids, + mems=mems, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + +@add_start_docstrings( + """ + The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive + input embeddings) + """, + TRANSFO_XL_START_DOCSTRING, +) +class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.transformer = TFTransfoXLMainLayer(config, name="transformer") + self.sample_softmax = config.sample_softmax + assert self.sample_softmax <= 0, ( + "Sampling from the softmax is not implemented yet. Please look at issue: #3310:" + " https://github.com/huggingface/transformers/issues/3310" + ) + + self.crit = TFAdaptiveSoftmaxMask( + config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit" + ) + + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError() + + def get_output_embeddings(self): + """Double-check if you are using adaptive softmax.""" + if len(self.crit.out_layers) > 0: + return self.crit.out_layers[-1] + return None + + def reset_memory_length(self, mem_len): + self.transformer.reset_memory_length(mem_len) + + def init_mems(self, bsz): + return self.transformer.init_mems(bsz) + + @unpack_inputs + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTransfoXLLMHeadModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + mems: List[tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: np.ndarray | tf.Tensor | None = None, + training: bool = False, + ) -> TFTransfoXLLMHeadModelOutput | Tuple[tf.Tensor]: + if input_ids is not None: + bsz, tgt_len = shape_list(input_ids)[:2] + else: + bsz, tgt_len = shape_list(inputs_embeds)[:2] + + transformer_outputs = self.transformer( + input_ids, + mems, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict, + training=training, + ) + + last_hidden = transformer_outputs[0] + pred_hid = last_hidden[:, -tgt_len:] + + softmax_output = self.crit(pred_hid, labels, training=training) + prediction_scores = softmax_output if labels is None else () + + if not return_dict: + return (prediction_scores,) + transformer_outputs[1:] + + return TFTransfoXLLMHeadModelOutput( + prediction_scores=prediction_scores, + mems=transformer_outputs.mems, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs): + inputs = {} + + # if past is defined in model kwargs then use it for faster decoding + if past_key_values: + input_ids = tf.expand_dims(input_ids[:, -1], axis=-1) + else: + input_ids = input_ids + + return inputs + + # Adapted from the torch tie_weights function + def tf_to_pt_weight_rename(self, tf_weight): + if self.config.tie_word_embeddings and "crit.out_layers" in tf_weight: + return tf_weight, tf_weight.replace("crit.out_layers", "transformer.word_emb.emb_layers") + elif self.config.tie_projs and "crit.out_projs" in tf_weight: + for i, tie_proj in enumerate(self.config.tie_projs): + if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: + # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] + return tf_weight, tf_weight.replace(f"crit.out_projs.{i}", "transformer.word_emb.emb_projs.0") + elif tie_proj and self.config.div_val != 1: + # self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] + return tf_weight, tf_weight.replace("crit.out_projs", "transformer.word_emb.emb_projs") + else: + return (tf_weight,) + + +@add_start_docstrings( + """ + The Transfo XL Model transformer with a sequence classification head on top (linear layer). + + [`TFTransfoXLForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-1,GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + TRANSFO_XL_START_DOCSTRING, +) +class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.score = keras.layers.Dense( + config.num_labels, + kernel_initializer=get_initializer(config.init_range), + name="score", + use_bias=False, + ) + self.transformer = TFTransfoXLMainLayer(config, name="transformer") + + def get_output_embeddings(self): + # Remove after transformers v4.32. Fix this model's `test_model_common_attributes` test too. + logger.warning( + "Sequence classification models do not have output embeddings. `.get_output_embeddings` will be removed " + "in transformers v4.32." + ) + return self.transformer.word_emb + + @unpack_inputs + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTransfoXLSequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + mems: List[tf.Tensor] | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFTransfoXLSequenceClassifierOutputWithPast]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + transformer_outputs = self.transformer( + input_ids=input_ids, + mems=mems, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + in_logits = None + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) + - 1 + ) + sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) + in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + loss = None + + if labels is not None: + if input_ids is not None: + batch_size, sequence_length = shape_list(input_ids)[:2] + else: + batch_size, sequence_length = shape_list(inputs_embeds)[:2] + assert self.config.pad_token_id is not None or batch_size == 1, ( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + + if not tf.is_tensor(sequence_lengths): + in_logits = logits[0:batch_size, sequence_lengths] + + loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])) + + pooled_logits = in_logits if in_logits is not None else logits + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTransfoXLSequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + mems=transformer_outputs.mems, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "TFAdaptiveEmbedding", + "TFTransfoXLForSequenceClassification", + "TFTransfoXLLMHeadModel", + "TFTransfoXLMainLayer", + "TFTransfoXLModel", + "TFTransfoXLPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..48205e06fb20a473959544db4971dff0d3e58cbf --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_tf_transfo_xl_utilities.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A TF 2.0 Adaptive Softmax for Transformer XL model. +""" + +import tensorflow as tf + +from ....modeling_tf_utils import keras +from ....tf_utils import shape_list + + +class TFAdaptiveSoftmaxMask(keras.layers.Layer): + def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs): + super().__init__(**kwargs) + + self.vocab_size = vocab_size + self.d_embed = d_embed + self.d_proj = d_proj + + self.cutoffs = cutoffs + [vocab_size] + self.cutoff_ends = [0] + self.cutoffs + self.div_val = div_val + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + self.keep_order = keep_order + + self.out_layers = [] + self.out_projs = [] + + def build(self, input_shape): + if self.n_clusters > 0: + self.cluster_weight = self.add_weight( + shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight" + ) + self.cluster_bias = self.add_weight( + shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias" + ) + + if self.div_val == 1: + for i in range(len(self.cutoffs)): + if self.d_proj != self.d_embed: + weight = self.add_weight( + shape=(self.d_embed, self.d_proj), + initializer="zeros", + trainable=True, + name=f"out_projs_._{i}", + ) + self.out_projs.append(weight) + else: + self.out_projs.append(None) + weight = self.add_weight( + shape=(self.vocab_size, self.d_embed), + initializer="zeros", + trainable=True, + name=f"out_layers_._{i}_._weight", + ) + bias = self.add_weight( + shape=(self.vocab_size,), + initializer="zeros", + trainable=True, + name=f"out_layers_._{i}_._bias", + ) + self.out_layers.append((weight, bias)) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = self.d_embed // (self.div_val**i) + + weight = self.add_weight( + shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name=f"out_projs_._{i}" + ) + self.out_projs.append(weight) + weight = self.add_weight( + shape=(r_idx - l_idx, d_emb_i), + initializer="zeros", + trainable=True, + name=f"out_layers_._{i}_._weight", + ) + bias = self.add_weight( + shape=(r_idx - l_idx,), + initializer="zeros", + trainable=True, + name=f"out_layers_._{i}_._bias", + ) + self.out_layers.append((weight, bias)) + super().build(input_shape) + + @staticmethod + def _logit(x, W, b, proj=None): + y = x + if proj is not None: + y = tf.einsum("ibd,ed->ibe", y, proj) + return tf.einsum("ibd,nd->ibn", y, W) + b + + @staticmethod + def _gather_logprob(logprob, target): + lp_size = shape_list(logprob) + r = tf.range(lp_size[0], dtype=target.dtype) + idx = tf.stack([r, target], 1) + return tf.gather_nd(logprob, idx) + + def call(self, hidden, target, return_mean=True, training=False): + head_logprob = 0 + if self.n_clusters == 0: + output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0]) + if target is not None: + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output) + out = tf.nn.log_softmax(output, axis=-1) + else: + hidden_sizes = shape_list(hidden) + out = [] + loss = tf.zeros(hidden_sizes[:2]) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + if target is not None: + mask = (target >= l_idx) & (target < r_idx) + mask_idx = tf.where(mask) + cur_target = tf.boolean_mask(target, mask) - l_idx + + if self.div_val == 1: + cur_W = self.out_layers[0][0][l_idx:r_idx] + cur_b = self.out_layers[0][1][l_idx:r_idx] + else: + cur_W = self.out_layers[i][0] + cur_b = self.out_layers[i][1] + + if i == 0: + cur_W = tf.concat([cur_W, self.cluster_weight], 0) + cur_b = tf.concat([cur_b, self.cluster_bias], 0) + + head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0]) + head_logprob = tf.nn.log_softmax(head_logit) + out.append(head_logprob[..., : self.cutoffs[0]]) + if target is not None: + cur_head_logprob = tf.boolean_mask(head_logprob, mask) + cur_logprob = self._gather_logprob(cur_head_logprob, cur_target) + else: + tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i]) + tail_logprob = tf.nn.log_softmax(tail_logit) + cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster + logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob + out.append(logprob_i) + if target is not None: + cur_head_logprob = tf.boolean_mask(head_logprob, mask) + cur_tail_logprob = tf.boolean_mask(tail_logprob, mask) + cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target) + cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1] + if target is not None: + loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss)) + out = tf.concat(out, axis=-1) + + if target is not None: + if return_mean: + loss = tf.reduce_mean(loss) + # Add the training-time loss value to the layer using `self.add_loss()`. + self.add_loss(loss) + + # Log the loss as a metric (we could log arbitrary metrics, + # including different metrics for training and inference. + self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "") + + return out diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..cf843850c05c411cec507971aeb8557a3038591b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py @@ -0,0 +1,1303 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular +https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py +""" + +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....modeling_utils import PreTrainedModel +from ....utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_transfo_xl import TransfoXLConfig +from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "transfo-xl/transfo-xl-wt103" +_CONFIG_FOR_DOC = "TransfoXLConfig" + + +def build_tf_to_pytorch_map(model, config): + """ + A map of modules from TF to PyTorch. This time I use a map to keep the PyTorch model as identical to the original + PyTorch model as possible. + """ + tf_to_pt_map = {} + + if hasattr(model, "transformer"): + # We are loading in a TransfoXLLMHeadModel => we will load also the Adaptive Softmax + tf_to_pt_map.update( + { + "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight, + "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias, + } + ) + for i, (out_l, proj_l, tie_proj) in enumerate( + zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs) + ): + layer_str = f"transformer/adaptive_softmax/cutoff_{i}/" + if config.tie_word_embeddings: + tf_to_pt_map.update({layer_str + "b": out_l.bias}) + else: + raise NotImplementedError + # I don't think this is implemented in the TF code + tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias}) + if not tie_proj: + tf_to_pt_map.update({layer_str + "proj": proj_l}) + # Now load the rest of the transformer + model = model.transformer + + # Embeddings + for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)): + layer_str = f"transformer/adaptive_embed/cutoff_{i}/" + tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l}) + + # Transformer blocks + for i, b in enumerate(model.layers): + layer_str = f"transformer/layer_{i}/" + tf_to_pt_map.update( + { + layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight, + layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias, + layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight, + layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight, + layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight, + layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight, + layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias, + layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight, + layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias, + layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight, + layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias, + } + ) + + # Relative positioning biases + if config.untie_r: + r_r_list = [] + r_w_list = [] + for b in model.layers: + r_r_list.append(b.dec_attn.r_r_bias) + r_w_list.append(b.dec_attn.r_w_bias) + else: + r_r_list = [model.r_r_bias] + r_w_list = [model.r_w_bias] + tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list}) + return tf_to_pt_map + + +def load_tf_weights_in_transfo_xl(model, config, tf_path): + """Load tf checkpoints in a pytorch model""" + try: + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + # Build TF to PyTorch weights loading map + tf_to_pt_map = build_tf_to_pytorch_map(model, config) + + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + tf_weights[name] = array + + for name, pointer in tf_to_pt_map.items(): + assert name in tf_weights + array = tf_weights[name] + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if "kernel" in name or "proj" in name: + array = np.transpose(array) + if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1: + # Here we will split the TF weights + assert len(pointer) == array.shape[0] + for i, p_i in enumerate(pointer): + arr_i = array[i, ...] + try: + assert p_i.shape == arr_i.shape + except AssertionError as e: + e.args += (p_i.shape, arr_i.shape) + raise + logger.info(f"Initialize PyTorch weight {name} for layer {i}") + p_i.data = torch.from_numpy(arr_i) + else: + try: + assert pointer.shape == array.shape, ( + f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + ) + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + tf_weights.pop(name, None) + tf_weights.pop(name + "/Adam", None) + tf_weights.pop(name + "/Adam_1", None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}") + return model + + +class PositionalEmbedding(nn.Module): + def __init__(self, demb): + super().__init__() + + self.demb = demb + + inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, pos_seq, bsz=None): + sinusoid_inp = torch.outer(pos_seq, self.inv_freq) + pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) + + if bsz is not None: + return pos_emb[:, None, :].expand(-1, bsz, -1) + else: + return pos_emb[:, None, :] + + +class PositionwiseFF(nn.Module): + def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5): + super().__init__() + + self.d_model = d_model + self.d_inner = d_inner + self.dropout = dropout + + self.CoreNet = nn.Sequential( + nn.Linear(d_model, d_inner), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout), + ) + + self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) + + self.pre_lnorm = pre_lnorm + + def forward(self, inp): + if self.pre_lnorm: + # layer normalization + positionwise feed-forward + core_out = self.CoreNet(self.layer_norm(inp)) + + # residual connection + output = core_out + inp + else: + # positionwise feed-forward + core_out = self.CoreNet(inp) + + # residual connection + layer normalization + output = self.layer_norm(inp + core_out) + + return output + + +class RelPartialLearnableMultiHeadAttn(nn.Module): + def __init__( + self, + n_head, + d_model, + d_head, + dropout, + dropatt=0, + pre_lnorm=False, + r_r_bias=None, + r_w_bias=None, + layer_norm_epsilon=1e-5, + ): + super().__init__() + + self.n_head = n_head + self.d_model = d_model + self.d_head = d_head + self.dropout = dropout + + self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) + + self.drop = nn.Dropout(dropout) + self.dropatt = nn.Dropout(dropatt) + self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) + + self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) + + self.scale = 1 / (d_head**0.5) + + self.pre_lnorm = pre_lnorm + + if r_r_bias is None or r_w_bias is None: # Biases are not shared + self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) + self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) + else: + self.r_r_bias = r_r_bias + self.r_w_bias = r_w_bias + + self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) + + def _rel_shift(self, x): + zero_pad_shape = (x.size(0), 1) + x.size()[2:] + zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=1) + + x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:] + x_padded = x_padded.view(*x_padded_shape) + + x = x_padded[1:].view_as(x) + + return x + + def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False): + qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) + + if mems is not None: + cat = torch.cat([mems, w], 0) + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(cat)) + else: + w_heads = self.qkv_net(cat) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + w_head_q = w_head_q[-qlen:] + else: + if self.pre_lnorm: + w_heads = self.qkv_net(self.layer_norm(w)) + else: + w_heads = self.qkv_net(w) + r_head_k = self.r_net(r) + + w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) + + klen = w_head_k.size(0) + + w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head + + r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head + + # compute attention score + rw_head_q = w_head_q + self.r_w_bias # qlen x bsz x n_head x d_head + AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head + + rr_head_q = w_head_q + self.r_r_bias + BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head + BD = self._rel_shift(BD) + + # [qlen x klen x bsz x n_head] + attn_score = AC + BD + attn_score.mul_(self.scale) + + mask_value = torch.finfo(attn_score.dtype).min + + # compute attention probability + if attn_mask is not None and torch.sum(attn_mask).item(): + attn_mask = attn_mask == 1 # Switch to bool + if attn_mask.dim() == 2: + attn_score = ( + attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score) + ) + elif attn_mask.dim() == 3: + attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score) + + # [qlen x klen x bsz x n_head] + attn_prob = nn.functional.softmax(attn_score, dim=1) + attn_prob = self.dropatt(attn_prob) + + # Mask heads if we want to + if head_mask is not None: + attn_prob = attn_prob * head_mask + + # compute attention vector + attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v)) + + # [qlen x bsz x n_head x d_head] + attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) + + # linear projection + attn_out = self.o_net(attn_vec) + attn_out = self.drop(attn_out) + + if self.pre_lnorm: + # residual connection + outputs = [w + attn_out] + else: + # residual connection + layer normalization + outputs = [self.layer_norm(w + attn_out)] + + if output_attentions: + outputs.append(attn_prob) + + return outputs + + +class RelPartialLearnableDecoderLayer(nn.Module): + def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs): + super().__init__() + + self.dec_attn = RelPartialLearnableMultiHeadAttn( + n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs + ) + self.pos_ff = PositionwiseFF( + d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon + ) + + def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None, output_attentions=False): + attn_outputs = self.dec_attn( + dec_inp, + r, + attn_mask=dec_attn_mask, + mems=mems, + head_mask=head_mask, + output_attentions=output_attentions, + ) + ff_output = self.pos_ff(attn_outputs[0]) + + outputs = [ff_output] + attn_outputs[1:] + + return outputs + + +class AdaptiveEmbedding(nn.Module): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False): + super().__init__() + + self.n_token = n_token + self.d_embed = d_embed + + self.cutoffs = cutoffs + [n_token] + self.div_val = div_val + self.d_proj = d_proj + + self.emb_scale = d_proj**0.5 + + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = nn.ModuleList() + self.emb_projs = nn.ParameterList() + if div_val == 1: + self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)) + if d_proj != d_embed: + self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val**i) + self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) + self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) + + def forward(self, inp): + if self.div_val == 1: + embed = self.emb_layers[0](inp) + if self.d_proj != self.d_embed: + embed = nn.functional.linear(embed, self.emb_projs[0]) + else: + param = next(self.parameters()) + inp_flat = inp.view(-1) + emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + indices_i = mask_i.nonzero().squeeze() + + if indices_i.numel() == 0: + continue + + inp_i = inp_flat.index_select(0, indices_i) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = nn.functional.linear(emb_i, self.emb_projs[i]) + + emb_flat.index_copy_(0, indices_i, emb_i) + + embed_shape = inp.size() + (self.d_proj,) + embed = emb_flat.view(embed_shape) + + embed.mul_(self.emb_scale) + + return embed + + +class TransfoXLPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TransfoXLConfig + load_tf_weights = load_tf_weights_in_transfo_xl + base_model_prefix = "transformer" + + def _init_weight(self, weight): + if self.config.init == "uniform": + nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) + elif self.config.init == "normal": + nn.init.normal_(weight, 0.0, self.config.init_std) + + def _init_bias(self, bias): + nn.init.constant_(bias, 0.0) + + def _init_weights(self, m): + """Initialize the weights.""" + classname = m.__class__.__name__ + if classname.find("Linear") != -1: + if hasattr(m, "weight") and m.weight is not None: + self._init_weight(m.weight) + if hasattr(m, "bias") and m.bias is not None: + self._init_bias(m.bias) + elif classname.find("AdaptiveEmbedding") != -1: + if hasattr(m, "emb_projs"): + for i in range(len(m.emb_projs)): + if m.emb_projs[i] is not None: + nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std) + elif classname.find("Embedding") != -1: + if hasattr(m, "weight"): + self._init_weight(m.weight) + elif classname.find("ProjectedAdaptiveLogSoftmax") != -1: + if hasattr(m, "cluster_weight") and m.cluster_weight is not None: + self._init_weight(m.cluster_weight) + if hasattr(m, "cluster_bias") and m.cluster_bias is not None: + self._init_bias(m.cluster_bias) + if hasattr(m, "out_projs"): + for i in range(len(m.out_projs)): + if m.out_projs[i] is not None: + nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std) + elif classname.find("LayerNorm") != -1: + if hasattr(m, "weight"): + nn.init.normal_(m.weight, 1.0, self.config.init_std) + if hasattr(m, "bias") and m.bias is not None: + self._init_bias(m.bias) + else: + if hasattr(m, "r_emb"): + self._init_weight(m.r_emb) + if hasattr(m, "r_w_bias"): + self._init_weight(m.r_w_bias) + if hasattr(m, "r_r_bias"): + self._init_weight(m.r_r_bias) + if hasattr(m, "r_bias"): + self._init_bias(m.r_bias) + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, layer: Optional[int] = -1): + """ + Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. Take care of tying + weights embeddings afterwards if the model class has a *tie_weights()* method. + + Arguments: + new_num_tokens: (*optional*) int: + New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at + the end. Reducing the size will remove vectors from the end. If not provided or None: does nothing and + just returns a pointer to the input tokens `torch.nn.Embeddings` Module of the model. + layer: (*optional*) int: + Layer of the *AdaptiveEmbedding* where the resizing should be done. Per default the last layer will be + resized. Be aware that when resizing other than the last layer, you have to ensure that the new + token(s) in the tokenizer are at the corresponding position. + + Return: `torch.nn.Embeddings` Pointer to the input tokens Embeddings Module of the model + """ + base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed + + if new_num_tokens is None: + return self.get_input_embeddings() + + new_num_tokens_layer, layer = self._get_new_num_tokens_layer(new_num_tokens, layer) + assert new_num_tokens_layer > 0, "The size of the new embedding layer cannot be 0 or less" + model_embeds = base_model._resize_token_embeddings(new_num_tokens_layer, layer) + + # Update base model and current model config + self.config.vocab_size = new_num_tokens + base_model.vocab_size = new_num_tokens + base_model.n_token = new_num_tokens + + new_embedding_shapes = self._get_embedding_shapes() + self._resize_cutoffs(new_num_tokens, new_num_tokens_layer, new_embedding_shapes, layer) + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _get_new_num_tokens_layer(self, new_num_tokens, layer): + embeddings = self.get_input_embeddings() + if layer == -1: + layer = len(embeddings.emb_layers) - 1 + assert 0 <= layer <= len(embeddings.emb_layers) - 1 + + new_num_tokens_layer = ( + new_num_tokens + - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[:layer]]) + - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[layer + 1 :]]) + ) + return new_num_tokens_layer, layer + + def _get_embedding_shapes(self): + embeddings = self.get_input_embeddings() + return [emb.weight.shape[0] for emb in embeddings.emb_layers] + + def _resize_token_embeddings(self, new_num_tokens, layer=-1): + embeddings = self.get_input_embeddings() + if new_num_tokens is None: + return embeddings + new_embeddings_layer = self._get_resized_embeddings(embeddings.emb_layers[layer], new_num_tokens) + embeddings.emb_layers[layer] = new_embeddings_layer + + self.set_input_embeddings(embeddings) + + return self.get_input_embeddings() + + def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer): + embeddings = self.get_input_embeddings() + + for i in range(layer, len(embeddings.cutoffs)): + embeddings.cutoffs[i] = sum(new_embedding_shapes[: i + 1]) + + embeddings.cutoff_ends = [0] + embeddings.cutoffs + embeddings.n_token = new_num_tokens + + self.config.cutoffs = embeddings.cutoffs[:-1] + + return embeddings.cutoffs + + +@dataclass +class TransfoXLModelOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + mems (`List[torch.FloatTensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor + mems: List[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class TransfoXLSequenceClassifierOutputWithPast(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + mems (`List[torch.FloatTensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + mems: List[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class TransfoXLLMHeadModelOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + losses (`torch.FloatTensor` of shape *(batch_size, sequence_length-1)*, *optional*, returned when `labels` is provided): + Language modeling losses (not reduced). + prediction_scores (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax). + mems (`List[torch.FloatTensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks). Can be used (see `mems` + input) to speed up sequential decoding. The token ids which have their past given to this model should not + be passed as input ids as they have already been computed. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `labels` is provided) + Reduced language modeling loss. + """ + + losses: Optional[torch.FloatTensor] = None + prediction_scores: Optional[torch.FloatTensor] = None + mems: List[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + loss: Optional[torch.FloatTensor] = None + + @property + def logits(self): + # prediction scores are the output of the adaptive softmax, see + # the file `modeling_transfo_xl_utilities`. Since the adaptive + # softmax returns the log softmax value, `self.prediction_scores` + # are strictly speaking not exactly `logits`, but behave the same + # way logits do. + return self.prediction_scores + + +TRANSFO_XL_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TRANSFO_XL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + mems (`List[torch.FloatTensor]` of length `config.n_layers`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see + `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems + given to this model should not be passed as `input_ids` as they have already been computed. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", + TRANSFO_XL_START_DOCSTRING, +) +class TransfoXLModel(TransfoXLPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.n_token = config.vocab_size + + self.d_embed = config.d_embed + self.d_model = config.d_model + self.n_head = config.n_head + self.d_head = config.d_head + + self.word_emb = AdaptiveEmbedding( + config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val + ) + + self.drop = nn.Dropout(config.dropout) + + self.n_layer = config.n_layer + self.mem_len = config.mem_len + self.attn_type = config.attn_type + + if not config.untie_r: + self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) + self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) + + self.layers = nn.ModuleList() + if config.attn_type == 0: # the default attention + for i in range(config.n_layer): + self.layers.append( + RelPartialLearnableDecoderLayer( + config.n_head, + config.d_model, + config.d_head, + config.d_inner, + config.dropout, + dropatt=config.dropatt, + pre_lnorm=config.pre_lnorm, + r_w_bias=None if config.untie_r else self.r_w_bias, + r_r_bias=None if config.untie_r else self.r_r_bias, + layer_norm_epsilon=config.layer_norm_epsilon, + ) + ) + else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints + raise NotImplementedError # Removed them to avoid maintaining dead code + + self.same_length = config.same_length + self.clamp_len = config.clamp_len + + if self.attn_type == 0: # default attention + self.pos_emb = PositionalEmbedding(self.d_model) + else: # learnable embeddings and absolute embeddings + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_emb + + def set_input_embeddings(self, new_embeddings): + self.word_emb = new_embeddings + + def backward_compatible(self): + self.sample_softmax = -1 + + def reset_memory_length(self, mem_len): + self.mem_len = mem_len + + def _prune_heads(self, heads): + logger.info("Head pruning is not implemented for Transformer-XL model") + pass + + def init_mems(self, bsz): + if self.mem_len > 0: + mems = [] + param = next(self.parameters()) + for i in range(self.n_layer): + empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device) + mems.append(empty) + + return mems + else: + return None + + def _update_mems(self, hids, mems, mlen, qlen): + # does not deal with None + if mems is None: + return None + + # mems is not None + assert len(hids) == len(mems), "len(hids) != len(mems)" + + # There are `mlen + qlen` steps that can be cached into mems + with torch.no_grad(): + new_mems = [] + end_idx = mlen + max(0, qlen) + beg_idx = max(0, end_idx - self.mem_len) + for i in range(len(hids)): + cat = torch.cat([mems[i], hids[i]], dim=0) + new_mems.append(cat[beg_idx:end_idx].detach()) + + return new_mems + + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TransfoXLModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + mems: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TransfoXLModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library + # so we transpose here from shape [bsz, len] to shape [len, bsz] + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_ids = input_ids.transpose(0, 1).contiguous() + qlen, bsz = input_ids.size() + elif inputs_embeds is not None: + inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() + qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if mems is None: + mems = self.init_mems(bsz) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) + # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) + head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) + head_mask = head_mask.to( + dtype=next(self.parameters()).dtype + ) # switch to float if need + fp16 compatibility + else: + head_mask = [None] * self.n_layer + + if inputs_embeds is not None: + word_emb = inputs_embeds + else: + word_emb = self.word_emb(input_ids) + + mlen = mems[0].size(0) if mems is not None else 0 + klen = mlen + qlen + if self.same_length: + all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool) + mask_len = klen - self.mem_len + if mask_len > 0: + mask_shift_len = qlen - mask_len + else: + mask_shift_len = qlen + dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 + else: + dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[ + :, :, None + ] + + hids = [] + attentions = [] if output_attentions else None + if self.attn_type == 0: # default + pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=torch.int64).type_as( + dtype=word_emb.dtype + ) + if self.clamp_len > 0: + pos_seq.clamp_(max=self.clamp_len) + pos_emb = self.pos_emb(pos_seq) + + core_out = self.drop(word_emb) + pos_emb = self.drop(pos_emb) + + for i, layer in enumerate(self.layers): + hids.append(core_out) + mems_i = None if mems is None else mems[i] + layer_outputs = layer( + core_out, + pos_emb, + dec_attn_mask=dec_attn_mask, + mems=mems_i, + head_mask=head_mask[i], + output_attentions=output_attentions, + ) + core_out = layer_outputs[0] + if output_attentions: + attentions.append(layer_outputs[1]) + else: # learnable embeddings and absolute embeddings + raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint + + core_out = self.drop(core_out) + + new_mems = self._update_mems(hids, mems, mlen, qlen) + + if output_hidden_states: + # Add last layer and transpose to library standard shape [bsz, len, hidden_dim] + hids.append(core_out) + hids = tuple(t.transpose(0, 1).contiguous() for t in hids) + else: + hids = None + if output_attentions: + # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] + attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) + # We transpose back here to shape [bsz, len, hidden_dim] + core_out = core_out.transpose(0, 1).contiguous() + + if not return_dict: + return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None) + + return TransfoXLModelOutput( + last_hidden_state=core_out, + mems=new_mems, + hidden_states=hids, + attentions=attentions, + ) + + +@add_start_docstrings( + """ + The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive + input embeddings) + """, + TRANSFO_XL_START_DOCSTRING, +) +class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): + _tied_weights_keys = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = TransfoXLModel(config) + self.sample_softmax = config.sample_softmax + self.trainer_compatible = getattr(config, "trainer_compatible", False) + + if not self.trainer_compatible: + warnings.warn( + "The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order " + "to use that updated output, please specify `trainer_compatible=True` as your configuration" + " attribute.", + DeprecationWarning, + ) + + assert self.sample_softmax <= 0, ( + "Sampling from the softmax is not implemented yet. Please look at issue: #3310:" + " https://github.com/huggingface/transformers/issues/3310" + ) + + self.crit = ProjectedAdaptiveLogSoftmax( + config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val + ) + + # Initialize weights and apply final processing + self.post_init() + + def tie_weights(self): + """ + Run this to be sure output and input (adaptive) softmax weights are tied + """ + + if self.config.tie_word_embeddings: + for i in range(len(self.crit.out_layers)): + self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i]) + if self.config.tie_projs: + for i, tie_proj in enumerate(self.config.tie_projs): + if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: + if self.config.torchscript: + self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone()) + else: + self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0] + elif tie_proj and self.config.div_val != 1: + if self.config.torchscript: + self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone()) + else: + self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] + + def reset_memory_length(self, mem_len): + self.transformer.reset_memory_length(mem_len) + + def init_mems(self, bsz): + return self.transformer.init_mems(bsz) + + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TransfoXLLMHeadModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + mems: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TransfoXLLMHeadModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is not None: + bsz, tgt_len = input_ids.size(0), input_ids.size(1) + elif inputs_embeds is not None: + bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1) + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + transformer_outputs = self.transformer( + input_ids, + mems=mems, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden = transformer_outputs[0] + pred_hid = last_hidden[:, -tgt_len:] + + if labels is not None: + # Prevents all labels being -100 and throwing an error + # when backwarding the loss + miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100 + if miss_valid_label: + # Sets an token, just to prevent loss from being NaN + labels[0, 1] = self.config.eos_token_id + + softmax_output = self.crit(pred_hid, labels) + prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else () + + if labels is not None: + losses = softmax_output.view(bsz, tgt_len - 1) + # Avoids from incorporating padding (-100) tokens into loss value + loss = losses[losses != 0].mean() + else: + losses, loss = None, None + + if not return_dict: + if self.trainer_compatible: + output = (prediction_scores, losses) if losses is not None else (prediction_scores,) + output += transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + output = (prediction_scores, *transformer_outputs[1:]) + output = ((losses,) + output) if losses is not None else output + return (output + (loss,)) if loss is not None else output + + return TransfoXLLMHeadModelOutput( + loss=loss, + prediction_scores=prediction_scores, + losses=losses, + mems=transformer_outputs.mems, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_output_embeddings(self): + """Double-check if you are using adaptive softmax.""" + if self.sample_softmax > 0: + return self.out_layer + else: + return self.crit.out_layers[-1] + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **model_kwargs): + inputs = {} + + # if past is defined in model kwargs then use it for faster decoding + if past_key_values: + inputs["mems"] = past_key_values + inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1) + else: + inputs["input_ids"] = input_ids + + return inputs + + def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer): + new_cutoffs = super()._resize_cutoffs(new_num_tokens, new_emb_size, new_embedding_shapes, layer) + + self.crit.cutoffs = new_cutoffs + self.crit.cutoff_ends = [0] + new_cutoffs + self.crit.n_token = new_num_tokens + + @staticmethod + def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]: + """ + This function is used to re-order the `mems` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `mems` with the correct beam_idx at every + generation step. + """ + return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems] + + +@add_start_docstrings( + """ + The Transformer-XL Model transformer with a sequence classification head on top (linear layer). + + [`TransfoXLForSequenceClassification`] uses the last token in order to do the classification, as other causal + models (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + TRANSFO_XL_START_DOCSTRING, +) +class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = TransfoXLModel(config) + self.score = nn.Linear(config.d_embed, self.num_labels, bias=False) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TransfoXLSequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + mems: Optional[List[torch.FloatTensor]] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TransfoXLSequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + mems=mems, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert self.config.pad_token_id is not None or batch_size == 1, ( + "Cannot handle batch sizes > 1 if no padding token is defined." + ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[range(batch_size), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TransfoXLSequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + mems=transformer_outputs.mems, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +__all__ = [ + "AdaptiveEmbedding", + "TransfoXLForSequenceClassification", + "TransfoXLLMHeadModel", + "TransfoXLModel", + "TransfoXLPreTrainedModel", + "load_tf_weights_in_transfo_xl", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..f76f3ccc6259fcb033b44eb43dd98be23482221c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py @@ -0,0 +1,251 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for PyTorch Transformer XL model. Directly adapted from https://github.com/kimiyoung/transformer-xl. +""" + +import torch +from torch import nn + + +# CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) +# CUDA_MINOR = int(torch.version.cuda.split('.')[1]) + + +class ProjectedAdaptiveLogSoftmax(nn.Module): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, keep_order=False): + super().__init__() + + self.n_token = n_token + self.d_embed = d_embed + self.d_proj = d_proj + + self.cutoffs = cutoffs + [n_token] + self.cutoff_ends = [0] + self.cutoffs + self.div_val = div_val + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + if self.n_clusters > 0: + self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) + self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) + + self.out_layers = nn.ModuleList() + self.out_projs = nn.ParameterList() + + if div_val == 1: + for i in range(len(self.cutoffs)): + if d_proj != d_embed: + self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) + else: + self.out_projs.append(None) + + self.out_layers.append(nn.Linear(d_embed, n_token)) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val**i) + + self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) + + self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx)) + + self.keep_order = keep_order + + def _compute_logit(self, hidden, weight, bias, proj): + if proj is None: + logit = nn.functional.linear(hidden, weight, bias=bias) + else: + # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: + proj_hid = nn.functional.linear(hidden, proj.t().contiguous()) + logit = nn.functional.linear(proj_hid, weight, bias=bias) + # else: + # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) + # if bias is not None: + # logit = logit + bias + + return logit + + def forward(self, hidden, labels=None, keep_order=False): + """ + Params: + hidden :: [len*bsz x d_proj] + labels :: [len*bsz] + + Return: + if labels is None: out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary else: out :: + [(len-1)*bsz] Negative log likelihood. We could replace this implementation by the native PyTorch one if + theirs had an option to set bias on all clusters in the native one. here: + https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 + """ + + if labels is not None: + # Shift so that tokens < n predict n + hidden = hidden[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + hidden = hidden.view(-1, hidden.size(-1)) + labels = labels.view(-1) + if hidden.size(0) != labels.size(0): + raise RuntimeError("Input and labels should have the same size in the batch dimension.") + else: + hidden = hidden.view(-1, hidden.size(-1)) + + if self.n_clusters == 0: + logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0]) + if labels is not None: + mask = labels != -100 + out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device) + out[mask] = ( + -nn.functional.log_softmax(logit, dim=-1)[mask].gather(1, labels[mask].unsqueeze(1)).squeeze(1) + ) + else: + out = nn.functional.log_softmax(logit, dim=-1) + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers[0].weight[l_idx:r_idx] + bias_i = self.out_layers[0].bias[l_idx:r_idx] + else: + weight_i = self.out_layers[i].weight + bias_i = self.out_layers[i].bias + + if i == 0: + weight_i = torch.cat([weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat([bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] + + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + head_logprob = nn.functional.log_softmax(head_logit, dim=1) + + if labels is None: + out = hidden.new_empty((head_logit.size(0), self.n_token)) + else: + out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device) + + offset = 0 + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + + if labels is not None: + mask_i = (labels >= l_idx) & (labels < r_idx) + indices_i = mask_i.nonzero().squeeze() + + if indices_i.numel() == 0: + continue + + target_i = labels.index_select(0, indices_i) - l_idx + head_logprob_i = head_logprob.index_select(0, indices_i) + hidden_i = hidden.index_select(0, indices_i) + else: + hidden_i = hidden + + if i == 0: + if labels is not None: + logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) + else: + out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]] + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] + + tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) + tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1) + cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster + if labels is not None: + logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather( + 1, target_i[:, None] + ).squeeze(1) + else: + logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i + out[:, l_idx:r_idx] = logprob_i + + if labels is not None: + if (hasattr(self, "keep_order") and self.keep_order) or keep_order: + out.index_copy_(0, indices_i, -logprob_i) + else: + out[offset : offset + logprob_i.size(0)].copy_(-logprob_i) + offset += logprob_i.size(0) + + return out + + def log_prob(self, hidden): + r""" + Computes log probabilities for all \\(n\_classes\\) From: + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.p + + Args: + hidden (Tensor): a minibatch of example + + Returns: + log-probabilities of for each class \\(c\\) in range \\(0 <= c <= n\_classes\\), where \\(n\_classes\\) is + a parameter passed to `AdaptiveLogSoftmaxWithLoss` constructor. Shape: + + - Input: \\((N, in\_features)\\) + - Output: \\((N, n\_classes)\\) + """ + if self.n_clusters == 0: + logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0]) + return nn.functional.log_softmax(logit, dim=-1) + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers[0].weight[l_idx:r_idx] + bias_i = self.out_layers[0].bias[l_idx:r_idx] + else: + weight_i = self.out_layers[i].weight + bias_i = self.out_layers[i].bias + + if i == 0: + weight_i = torch.cat([weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat([bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + + out = hidden.new_empty((head_logit.size(0), self.n_token)) + head_logprob = nn.functional.log_softmax(head_logit, dim=1) + + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] + + if i == 0: + out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]] + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] + + tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) + tail_logprob_i = nn.functional.log_softmax(tail_logit_i, dim=1) + + logprob_i = head_logprob[:, -i] + tail_logprob_i + out[:, start_idx, stop_idx] = logprob_i + + return out diff --git a/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py new file mode 100644 index 0000000000000000000000000000000000000000..f9a5fb7b3466b0091717dfe2a264c16b0a1aeeeb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/transfo_xl/tokenization_transfo_xl.py @@ -0,0 +1,821 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tokenization classes for Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. +""" + +import glob +import os +import pickle +import re +from collections import Counter, OrderedDict +from typing import List, Optional, Tuple + +import numpy as np + +from ....tokenization_utils import PreTrainedTokenizer +from ....utils import ( + cached_file, + is_sacremoses_available, + is_torch_available, + logging, + requires_backends, + strtobool, + torch_only_method, +) + + +if is_sacremoses_available(): + import sacremoses as sm + + +if is_torch_available(): + import torch + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = { + "pretrained_vocab_file": "vocab.pkl", + "pretrained_vocab_file_torch": "vocab.bin", + "vocab_file": "vocab.txt", +} + + +PRETRAINED_CORPUS_ARCHIVE_MAP = { + "transfo-xl/transfo-xl-wt103": "https://huggingface.co/transfo-xl/transfo-xl-wt103/resolve/main/corpus.bin", +} +CORPUS_NAME = "corpus.bin" + +MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ " +DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")] + + +def tokenize_numbers(text_array: List[str]) -> List[str]: + """ + Splits large comma-separated numbers and floating point values. This is done by replacing commas with ' @,@ ' and + dots with ' @.@ '. + + Args: + text_array: An already tokenized text as list. + + Returns: + A list of strings with tokenized numbers. + + Example: + + ```python + >>> tokenize_numbers(["$", "5,000", "1.73", "m"]) + ['$', '5', '@,@', '000', '1', '@.@', '73', 'm'] + ```""" + tokenized = [] + for i in range(len(text_array)): + reg, sub = MATCH_NUMBERS + replaced = re.sub(reg, sub, text_array[i]).split() + tokenized.extend(replaced) + + return tokenized + + +def detokenize_numbers(text: str) -> str: + """ + Inverts the operation of *tokenize_numbers*. This is replacing ' @,@ ' and ' @.@' by ',' and '.'. + + Args: + text: A string where the number should be detokenized. + + Returns: + A detokenized string. + + Example: + + ```python + >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m") + '$ 5,000 1.73 m' + ```""" + for reg, sub in DETOKENIZE_NUMBERS: + text = re.sub(reg, sub, text) + return text + + +class TransfoXLTokenizer(PreTrainedTokenizer): + """ + Construct a Transformer-XL tokenizer adapted from Vocab class in [the original + code](https://github.com/kimiyoung/transformer-xl). The Transformer-XL tokenizer is a word-level tokenizer (no + sub-word tokenization). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + special (`List[str]`, *optional*): + A list of special tokens (to be treated by the original implementation of this tokenizer). + min_freq (`int`, *optional*, defaults to 0): + The minimum number of times a token has to be present in order to be kept in the vocabulary (otherwise it + will be mapped to `unk_token`). + max_size (`int`, *optional*): + The maximum size of the vocabulary. If left unset, it will default to the size of the vocabulary found + after excluding the tokens according to the `min_freq` rule. + lower_case (`bool`, *optional*, defaults to `False`): + Whether or not to lowercase the input when tokenizing. + delimiter (`str`, *optional*): + The delimiter used between tokens. + vocab_file (`str`, *optional*): + File containing the vocabulary (from the original implementation). + pretrained_vocab_file (`str`, *optional*): + File containing the vocabulary as saved with the `save_pretrained()` method. + never_split (`List[str]`, *optional*): + List of tokens that should never be split. If no list is specified, will simply use the existing special + tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + additional_special_tokens (`List[str]`, *optional*, defaults to `['']`): + A list of additional special tokens (for the HuggingFace functionality). + language (`str`, *optional*, defaults to `"en"`): + The language of this tokenizer (used for mose preprocessing). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids"] + + def __init__( + self, + special=None, + min_freq=0, + max_size=None, + lower_case=False, + delimiter=None, + vocab_file=None, + pretrained_vocab_file: Optional[str] = None, + never_split=None, + unk_token="", + eos_token="", + additional_special_tokens=[""], + language="en", + **kwargs, + ): + logger.error( + "`TransfoXL` was deprecated due to security issues linked to `pickle.load` in `TransfoXLTokenizer`. " + "See more details on this model's documentation page: " + "`https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/transfo-xl.md`." + ) + + requires_backends(self, "sacremoses") + if special is None: + special = [] + self.counter = Counter() + self.special = special + self.min_freq = min_freq + self.max_size = max_size + self.lower_case = lower_case + self.delimiter = delimiter + self.vocab_file = vocab_file + self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~' + self.punction_without_space_before_pattern = re.compile(rf"[^\s][{self.punctuation_symbols}]") + self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern() + self.language = language + self.moses_punct_normalizer = sm.MosesPunctNormalizer(language) + self.moses_tokenizer = sm.MosesTokenizer(language) + self.moses_detokenizer = sm.MosesDetokenizer(language) + self.idx2sym = [] + self.sym2idx = OrderedDict() + # This try... catch... is not beautiful but honestly this tokenizer was not made to be used + # in a library like ours, at all. + try: + vocab_dict = None + if pretrained_vocab_file is not None: + # Priority on pickle files (support PyTorch and TF) + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is " + "potentially malicious. It's recommended to never unpickle data that could have come from an " + "untrusted source, or that could have been tampered with. If you already verified the pickle " + "data and decided to use it, you can set the environment variable " + "`TRUST_REMOTE_CODE` to `True` to allow it." + ) + with open(pretrained_vocab_file, "rb") as f: + vocab_dict = pickle.load(f) + + # Loading a torch-saved transfo-xl vocab dict with pickle results in an integer + # Entering this if statement means that we tried to load a torch-saved file with pickle, and we failed. + # We therefore load it with torch, if it's available. + if isinstance(vocab_dict, int): + if not is_torch_available(): + raise ImportError( + "Not trying to load dict with PyTorch as you need to install pytorch to load " + "from a PyTorch pretrained vocabulary, " + "or activate it with environment variables USE_TORCH=1 and USE_TF=0." + ) + vocab_dict = torch.load(pretrained_vocab_file, weights_only=True) + + if vocab_dict is not None: + for key, value in vocab_dict.items(): + if key not in self.__dict__ or key in ["sym2idx", "idx2sym"]: + self.__dict__[key] = value + elif vocab_file is not None: + self.build_vocab() + + except Exception as e: + raise ValueError( + f"Unable to parse file {pretrained_vocab_file}. Unknown format. " + "If you tried to load a model saved through TransfoXLTokenizerFast, " + "please note they are not compatible." + ) from e + + if vocab_file is not None: + self.build_vocab() + + super().__init__( + special=special, + min_freq=min_freq, + max_size=max_size, + lower_case=lower_case, + delimiter=delimiter, + vocab_file=vocab_file, + pretrained_vocab_file=pretrained_vocab_file, + never_split=never_split, + unk_token=unk_token, + eos_token=eos_token, + additional_special_tokens=additional_special_tokens, + language=language, + **kwargs, + ) + + # these are not required to initialize the parent class as only used when tokenizing. + if never_split is None: + never_split = self.all_special_tokens + self.never_split = never_split + + @property + def do_lower_case(self): + return self.lower_case + + def _compile_space_around_punctuation_pattern(self): + look_ahead_for_special_token = f"(?=[{self.punctuation_symbols}])" + look_ahead_to_match_all_except_space = r"(?=[^\s])" + return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space) + + def count_file(self, path, verbose=False, add_eos=False): + if verbose: + logger.info(f"counting file {path} ...") + assert os.path.exists(path), f"Input file {path} not found" + + sents = [] + with open(path, "r", encoding="utf-8") as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(f" line {idx}") + symbols = self.tokenize(line, add_eos=add_eos) + self.counter.update(symbols) + sents.append(symbols) + + return sents + + def count_sents(self, sents, verbose=False): + """ + sents : a list of sentences, each a list of tokenized symbols + """ + if verbose: + logger.info(f"counting {len(sents)} sents ...") + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(f" line {idx}") + self.counter.update(symbols) + + def _build_from_file(self, vocab_file): + self.idx2sym = [] + self.sym2idx = OrderedDict() + + with open(vocab_file, "r", encoding="utf-8") as f: + for line in f: + symb = line.strip().split()[0] + self.add_symbol(symb) + if "" in self.sym2idx: + self.unk_idx = self.sym2idx[""] + elif "" in self.sym2idx: + self.unk_idx = self.sym2idx[""] + else: + raise ValueError("Token not in vocabulary and no token in vocabulary for replacement.") + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["pretrained_vocab_file"], + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "wb") as f: + pickle.dump(self.__dict__, f) + return (vocab_file,) + + def build_vocab(self): + if self.vocab_file: + logger.info(f"building vocab from {self.vocab_file}") + self._build_from_file(self.vocab_file) + logger.info(f"Final vocab size {len(self.sym2idx)}") + else: + logger.info(f"building vocab with min_freq={self.min_freq}, max_size={self.max_size}") + self.idx2sym = [] + self.sym2idx = OrderedDict() + + for sym in self.special: + self.add_special(sym) + + for sym, cnt in self.counter.most_common(self.max_size): + if cnt < self.min_freq: + break + self.add_symbol(sym) + + logger.info(f"Final vocab size {len(self.sym2idx)} from {len(self.counter)} unique tokens") + + @torch_only_method + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False): + if verbose: + logger.info(f"encoding file {path} ...") + assert os.path.exists(path), f"Output file {path} not found" + encoded = [] + with open(path, "r", encoding="utf-8") as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(f" line {idx}") + symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + @torch_only_method + def encode_sents(self, sents, ordered=False, verbose=False): + if verbose: + logger.info(f"encoding {len(sents)} sents ...") + encoded = [] + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + logger.info(f" line {idx}") + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def add_special(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + setattr(self, f"{sym.strip('<>')}_idx", self.sym2idx[sym]) + + def add_symbol(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + + def move_added_token(self, token: str, target_idx: int): + """ + Moves an added token to a specific position in the vocab. This method should be used when resizing an embedding + layer other than the last one in the `AdaptiveEmbedding` in order to move the token in the tokenizer from the + default position (at the very end) to the desired one. + + Args: + token: The token to move to a specific position in the vocab. + target_idx: The position where the token should be moved to. + """ + assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token" + assert token not in self.idx2sym, "Token which should be moved is already in vocab" + + # Insert sym into vocab + self.idx2sym.insert(target_idx, token) + self.sym2idx[token] = target_idx + + # Shift following indices in sym2idx + for idx in range(target_idx + 1, len(self.idx2sym)): + current_sym = self.idx2sym[idx] + self.sym2idx[current_sym] = idx + + # Delete token from added_tokens + old_index = self._added_tokens_encoder.pop(token) + self._added_tokens_decoder.pop(old_index) + + def moses_punct_norm(self, text): + return self.moses_punct_normalizer.normalize(text) + + def moses_tokenize(self, text): + return self.moses_tokenizer.tokenize( + text, aggressive_dash_splits=True, return_str=False, escape=False, protected_patterns=self.never_split + ) + + def moses_pipeline(self, text: str) -> List[str]: + """ + Does basic tokenization using [`sacremoses.MosesPunctNormalizer`] and [`sacremoses.MosesTokenizer`] with + *aggressive_dash_splits=True* (see [`sacremoses.tokenize.MosesTokenizer.tokenize`]). Additionally, large + comma-separated numbers and floating point values are split. E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000 + people are 1 @.@ 80m tall" + + Args: + text: Text to be tokenize + + Returns: + A list of tokenized string + + Example: + + ```python + >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl/transfo-xl-wt103") + >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall") + ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall'] + ```""" + text = self.moses_punct_norm(text) + text = self.moses_tokenize(text) + text = tokenize_numbers(text) + return text + + def _convert_id_to_token(self, idx): + """Converts an id in a token (BPE) using the vocab.""" + assert 0 <= idx < len(self), f"Index {idx} out of vocabulary range" + return self.idx2sym[idx] + + def _convert_token_to_id(self, sym): + """Converts a token (str) in an id using the vocab.""" + if sym in self.sym2idx: + return self.sym2idx[sym] + else: + # logger.info(f'encounter unk {sym}') + # assert '' not in sym + if hasattr(self, "unk_idx"): + return self.sym2idx.get(sym, self.unk_idx) + # Backward compatibility with pre-trained models + elif "" in self.sym2idx: + return self.sym2idx[""] + elif "" in self.sym2idx: + return self.sym2idx[""] + else: + raise ValueError("Token not in vocabulary and no token in vocabulary for replacement.") + + def convert_tokens_to_string(self, tokens): + """ + Converts a sequence of tokens (string) in a single string. Additionally, the split numbers are converted back + into it's original form. + """ + out_string = self.moses_detokenizer.detokenize(tokens) + return detokenize_numbers(out_string).strip() + + @torch_only_method + def convert_to_tensor(self, symbols): + return torch.LongTensor(self.convert_tokens_to_ids(symbols)) + + @property + def vocab_size(self): + return len(self.idx2sym) + + def get_vocab(self): + vocab = self.sym2idx.copy() + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, line, add_eos=False, add_double_eos=False): + line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() + + # empty delimiter '' will evaluate False + if self.delimiter == "": + symbols = line + else: + symbols = self.moses_pipeline(line) + + if add_double_eos: # lm1b + return [""] + symbols + [""] + elif add_eos: + return symbols + [""] + else: + return symbols + + +class LMOrderedIterator: + def __init__(self, data, bsz, bptt, device="cpu", ext_len=None): + """ + data -- LongTensor -- the LongTensor is strictly ordered + """ + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + + # Work out how cleanly we can divide the dataset into bsz parts. + self.n_step = data.size(0) // bsz + + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, self.n_step * bsz) + + # Evenly divide the data across the bsz batches. + self.data = data.view(bsz, -1).t().contiguous().to(device) + + # Number of mini-batches + self.n_batch = (self.n_step + self.bptt - 1) // self.bptt + + def get_batch(self, i, bptt=None): + if bptt is None: + bptt = self.bptt + seq_len = min(bptt, self.data.size(0) - 1 - i) + + end_idx = i + seq_len + beg_idx = max(0, i - self.ext_len) + + data = self.data[beg_idx:end_idx] + target = self.data[i + 1 : i + 1 + seq_len] + + data_out = data.transpose(0, 1).contiguous().to(self.device) + target_out = target.transpose(0, 1).contiguous().to(self.device) + + return data_out, target_out, seq_len + + def get_fixlen_iter(self, start=0): + for i in range(start, self.data.size(0) - 1, self.bptt): + yield self.get_batch(i) + + def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): + max_len = self.bptt + max_deviation * std + i = start + while True: + bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0 + bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) + data, target, seq_len = self.get_batch(i, bptt) + i += seq_len + yield data, target, seq_len + if i >= self.data.size(0) - 2: + break + + def __iter__(self): + return self.get_fixlen_iter() + + +class LMShuffledIterator: + def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False): + """ + data -- list[LongTensor] -- there is no order among the LongTensors + """ + self.data = data + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self): + # index iterator + epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data))) + + # sentence iterator + for idx in epoch_indices: + yield self.data[idx] + + @torch_only_method + def stream_iterator(self, sent_stream): + # streams for each data in the batch + streams = [None] * self.bsz + + data = torch.LongTensor(self.bptt, self.bsz) + target = torch.LongTensor(self.bptt, self.bsz) + + n_retain = 0 + + while True: + # data : [n_retain+bptt x bsz] + # target : [bptt x bsz] + data[n_retain:].fill_(-1) + target.fill_(-1) + + valid_batch = True + + for i in range(self.bsz): + n_filled = 0 + try: + while n_filled < self.bptt: + if streams[i] is None or len(streams[i]) <= 1: + streams[i] = next(sent_stream) + # number of new tokens to fill in + n_new = min(len(streams[i]) - 1, self.bptt - n_filled) + # first n_retain tokens are retained from last batch + data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new] + target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1] + streams[i] = streams[i][n_new:] + n_filled += n_new + except StopIteration: + valid_batch = False + break + + if not valid_batch: + return + + data_out = data.transpose(0, 1).contiguous().to(self.device) + target_out = target.transpose(0, 1).contiguous().to(self.device) + + yield data_out, target_out, self.bptt + + n_retain = min(data.size(0), self.ext_len) + if n_retain > 0: + data[:n_retain] = data[-n_retain:] + data.resize_(n_retain + self.bptt, data.size(1)) + + def __iter__(self): + # sent_stream is an iterator + sent_stream = self.get_sent_stream() + + for batch in self.stream_iterator(sent_stream): + yield batch + + +class LMMultiFileIterator(LMShuffledIterator): + def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False): + self.paths = paths + self.vocab = vocab + + self.bsz = bsz + self.bptt = bptt + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self, path): + sents = self.vocab.encode_file(path, add_double_eos=True) + if self.shuffle: + np.random.shuffle(sents) + sent_stream = iter(sents) + + return sent_stream + + def __iter__(self): + if self.shuffle: + np.random.shuffle(self.paths) + + for path in self.paths: + # sent_stream is an iterator + sent_stream = self.get_sent_stream(path) + for batch in self.stream_iterator(sent_stream): + yield batch + + +class TransfoXLCorpus: + @classmethod + @torch_only_method + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a pre-processed corpus. + """ + vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + is_local = os.path.isdir(pretrained_model_name_or_path) + # redirect to the cache, if necessary + try: + resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list" + f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'" + f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url." + ) + return None + if is_local: + logger.info(f"loading corpus file {resolved_corpus_file}") + else: + logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}") + + # Instantiate tokenizer. + corpus = cls(*inputs, **kwargs) + corpus_dict = torch.load(resolved_corpus_file, weights_only=True) + for key, value in corpus_dict.items(): + corpus.__dict__[key] = value + corpus.vocab = vocab + if corpus.train is not None: + corpus.train = torch.tensor(corpus.train, dtype=torch.long) + if corpus.valid is not None: + corpus.valid = torch.tensor(corpus.valid, dtype=torch.long) + if corpus.test is not None: + corpus.test = torch.tensor(corpus.test, dtype=torch.long) + return corpus + + def __init__(self, *args, **kwargs): + self.vocab = TransfoXLTokenizer(*args, **kwargs) + self.dataset = None + self.train = None + self.valid = None + self.test = None + + def build_corpus(self, path, dataset): + self.dataset = dataset + + if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: + self.vocab.count_file(os.path.join(path, "train.txt")) + self.vocab.count_file(os.path.join(path, "valid.txt")) + self.vocab.count_file(os.path.join(path, "test.txt")) + elif self.dataset == "wt103": + self.vocab.count_file(os.path.join(path, "train.txt")) + elif self.dataset == "lm1b": + train_path_pattern = os.path.join( + path, + "1-billion-word-language-modeling-benchmark-r13output", + "training-monolingual.tokenized.shuffled", + "news.en-*", + ) + train_paths = glob.glob(train_path_pattern) + # the vocab will load from file when build_vocab() is called + + self.vocab.build_vocab() + + if self.dataset in ["ptb", "wt2", "wt103"]: + self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True) + self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True) + self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True) + elif self.dataset in ["enwik8", "text8"]: + self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False) + self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False) + self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False) + elif self.dataset == "lm1b": + self.train = train_paths + self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True) + self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True) + + def get_iterator(self, split, *args, **kwargs): + if split == "train": + if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: + data_iter = LMOrderedIterator(self.train, *args, **kwargs) + elif self.dataset == "lm1b": + kwargs["shuffle"] = True + data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) + elif split in ["valid", "test"]: + data = self.valid if split == "valid" else self.test + if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: + data_iter = LMOrderedIterator(data, *args, **kwargs) + elif self.dataset == "lm1b": + data_iter = LMShuffledIterator(data, *args, **kwargs) + else: + data_iter = None + raise ValueError(f"Split not recognized: {split}") + + return data_iter + + +@torch_only_method +def get_lm_corpus(datadir, dataset): + fn = os.path.join(datadir, "cache.pt") + fn_pickle = os.path.join(datadir, "cache.pkl") + if os.path.exists(fn): + logger.info("Loading cached dataset...") + corpus = torch.load(fn_pickle, weights_only=True) + elif os.path.exists(fn): + logger.info("Loading cached dataset from pickle...") + if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")): + raise ValueError( + "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially " + "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or " + "that could have been tampered with. If you already verified the pickle data and decided to use it, " + "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it." + ) + with open(fn, "rb") as fp: + corpus = pickle.load(fp) + else: + logger.info(f"Producing dataset {dataset}...") + kwargs = {} + if dataset in ["wt103", "wt2"]: + kwargs["special"] = [""] + kwargs["lower_case"] = False + elif dataset == "ptb": + kwargs["special"] = [""] + kwargs["lower_case"] = True + elif dataset == "lm1b": + kwargs["special"] = [] + kwargs["lower_case"] = False + kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt") + elif dataset in ["enwik8", "text8"]: + pass + + corpus = TransfoXLCorpus(datadir, dataset, **kwargs) + torch.save(corpus, fn) + + return corpus + + +__all__ = ["TransfoXLCorpus", "TransfoXLTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..941db2f6ac5fefe66e06d598e1adbe0db1cf1769 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/__init__.py @@ -0,0 +1,20 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_tvlt import * + from .feature_extraction_tvlt import * + from .processing_tvlt import * + from .modeling_tvlt import * + from .image_processing_tvlt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py new file mode 100644 index 0000000000000000000000000000000000000000..bf159fa7e0b7754778ebda57b46dea4479fdc3c9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/configuration_tvlt.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2023 MURGe-Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TVLT model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class TvltConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`TvltModel`]. It is used to instantiate a TVLT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the TVLT + [ZinengTang/tvlt-base](https://huggingface.co/ZinengTang/tvlt-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + spectrogram_length (`int`, *optional*, defaults to 2048): + The time length of each audio spectrogram. + frequency_length (`int`, *optional*, defaults to 128): + The frequency length of audio spectrogram. + image_patch_size (`List[int]`, *optional*, defaults to `[16, 16]`): + The size (resolution) of each image patch. + audio_patch_size (`List[int]`, *optional*, defaults to `[16, 16]`): + The size (resolution) of each audio patch. + num_image_channels (`int`, *optional*, defaults to 3): + The number of input image channels. + num_audio_channels (`int`, *optional*, defaults to 1): + The number of input audio channels. + num_frames (`int`, *optional*, defaults to 8): + The maximum number of frames for an input video. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + use_mean_pooling (`bool`, *optional*, defaults to `False`): + Whether to mean pool the final hidden states instead of using the final hidden state of the [CLS] token. + decoder_num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the decoder. + decoder_hidden_size (`int`, *optional*, defaults to 512): + Dimensionality of the decoder. + decoder_num_hidden_layers (`int`, *optional*, defaults to 8): + Number of hidden layers in the decoder. + decoder_intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder. + pixel_mask_ratio (`float`, *optional*, defaults to 0.75): + Image patch masking ratio. + audio_mask_ratio (`float`, *optional*, defaults to 0.15): + Audio patch masking ratio. + audio_mask_type (`str`, *optional*, defaults to `"frame-level"`): + Audio patch masking type, choose between "frame-level" and "patch-level". + task_matching (`bool`, *optional*, defaults to `True`): + Whether to use vision audio matching task in pretraining. + task_mae (`bool`, *optional*, defaults to `True`): + Whether to use the masked auto-encoder (MAE) in pretraining. + loss_type (`str`, *optional*, defaults to `"classification"`): + Loss types including regression and classification. + + Example: + + ```python + >>> from transformers import TvltConfig, TvltModel + + >>> # # Initializing a TVLT ZinengTang/tvlt-base style configuration + >>> configuration = TvltConfig() + + >>> # # Initializing a model (with random weights) from the ZinengTang/tvlt-base style configuration + >>> model = TvltModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "tvlt" + + def __init__( + self, + image_size=224, + spectrogram_length=2048, + frequency_length=128, + image_patch_size=[16, 16], + audio_patch_size=[16, 16], + num_image_channels=3, + num_audio_channels=1, + num_frames=8, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + qkv_bias=True, + use_mean_pooling=False, + decoder_num_attention_heads=16, + decoder_hidden_size=512, + decoder_num_hidden_layers=8, + decoder_intermediate_size=2048, + pixel_mask_ratio=0.75, + audio_mask_ratio=0.15, + audio_mask_type="frame-level", + task_matching=True, + task_mae=True, + loss_type="classification", + **kwargs, + ): + super().__init__(**kwargs) + + if audio_mask_type not in ("frame-level", "patch_level"): + raise ValueError( + "audio_mask_type must be one of two acceptable strategies - {'frame_level', 'patch-level') " + f"got {audio_mask_type}" + ) + + self.image_size = image_size + self.spectrogram_length = spectrogram_length + self.frequency_length = frequency_length + self.image_patch_size = image_patch_size + self.audio_patch_size = audio_patch_size + self.num_image_channels = num_image_channels + self.num_audio_channels = num_audio_channels + self.num_frames = num_frames + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.qkv_bias = qkv_bias + self.use_mean_pooling = use_mean_pooling + + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_hidden_size = decoder_hidden_size + self.decoder_num_hidden_layers = decoder_num_hidden_layers + self.decoder_intermediate_size = decoder_intermediate_size + self.pixel_mask_ratio = pixel_mask_ratio + self.audio_mask_ratio = audio_mask_ratio + self.audio_mask_type = audio_mask_type + + self.task_matching = task_matching + self.task_mae = task_mae + self.loss_type = loss_type + + +__all__ = ["TvltConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py new file mode 100644 index 0000000000000000000000000000000000000000..bbbfac9031b9be499bd168c74681f32e32357c5b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/feature_extraction_tvlt.py @@ -0,0 +1,233 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for TVLT.""" + +from math import ceil +from typing import List, Optional, Union + +import numpy as np + +from ....audio_utils import mel_filter_bank, spectrogram, window_function +from ....feature_extraction_sequence_utils import BatchFeature, SequenceFeatureExtractor +from ....utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class TvltFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a TVLT audio feature extractor. This feature extractor can be used to prepare audios for the model. + + This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users + should refer to this superclass for more information regarding those methods. + + Args: + spectrogram_length (`Dict[str, int]` *optional*, defaults to 2048): + The time length of each audio spectrogram. + num_channels (`int` *optional*, defaults to 1): + Number of audio channels. + patch_size (`List[int]` *optional*, defaults to `[16, 16]`): + The patch size of audio patch embedding. + feature_size (`int`, *optional*, defaults to 128): + The frequency length of audio spectrogram. + sampling_rate (`int`, *optional*, defaults to 44100): + The sampling rate at which the audio files should be digitalized expressed in Hertz (Hz). + hop_length_to_sampling_rate (`int`, *optional*, defaults to 86): + Hop length is length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients. + For example, with sampling rate 44100, the hop length is 512, with 44100 / 512 = 86 + n_fft (`int`, *optional*, defaults to 2048): + Size of the Fourier transform. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + """ + + model_input_names = ["audio_values", "audio_mask"] + + def __init__( + self, + spectrogram_length=2048, + num_channels=1, + patch_size=[16, 16], + feature_size=128, + sampling_rate=44100, + hop_length_to_sampling_rate=86, + n_fft=2048, + padding_value=0.0, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + **kwargs, + ) + + self.spectrogram_length = spectrogram_length + self.num_channels = num_channels + self.patch_size = patch_size + self.freq_len = feature_size // self.patch_size[1] + self.n_fft = n_fft + self.hop_length = sampling_rate // hop_length_to_sampling_rate + self.sampling_rate = sampling_rate + self.padding_value = padding_value + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + n_fft // 2, + num_mel_filters=feature_size, + min_frequency=0.0, + max_frequency=22050.0, + sampling_rate=sampling_rate, + norm="slaney", + mel_scale="slaney", + ).T + + def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray: + """ + Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch + implementation with 1e-5 tolerance. + """ + log_spec = spectrogram( + waveform, + window_function(self.n_fft, "hann"), + frame_length=self.n_fft, + hop_length=self.hop_length, + power=2.0, + mel_filters=self.mel_filters.T, + log_mel="dB", + db_range=80.0, + ) + log_spec = log_spec[:, :-1] + log_spec = log_spec - 20.0 + log_spec = np.clip(log_spec / 40.0, -2.0, 0.0) + 1.0 + return log_spec + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = True, + sampling_rate: Optional[int] = None, + resample: bool = False, + mask_audio: bool = False, + **kwargs, + ) -> BatchFeature: + """ + Main method to prepare one or several audio(s) for the model. + + Args: + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_attention_mask (`bool`, *optional*, default to `True`): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. [What are attention masks?](../glossary#attention-mask) + + + + For TvltTransformer models, `attention_mask` should alwys be passed for batched inference, to avoid + subtle bugs. + + + + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition + pipeline. Current model supports sampling rate 16000 and 44100. + resample (`bool`, *optional*, defaults to `False`): + If the sampling rate is not matched, resample the input audio to match. + mask_audio (`bool`, *optional*, defaults to `False`): + Whether or not to mask input audio for MAE task. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **audio_values** -- Audio values to be fed to a model, of shape (batch_size, num_channels, height, + width). + + - **audio_mask** -- Audio masks to be fed to a model, of shape (batch_size, num_audio_patches). + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + "This feature extractor is set to support sampling rate" + f" of {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled" + f" with {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + "It is strongly recommended to pass the `sampling_rate` argument to this function. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + if is_batched: + raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): + raw_speech = raw_speech.astype(np.float32) + # always return batch + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + # Convert audio signals to log mel spectrograms, truncate by time axis + audio_features = [ + self._np_extract_fbank_features(waveform.squeeze()).T[: self.spectrogram_length] for waveform in raw_speech + ] + if isinstance(audio_features[0], List): + audio_features = [np.asarray(feature, dtype=np.float32) for feature in audio_features] + + # Create audio attention mask + max_patch_len = max( + [ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len for feature in audio_features] + ) # The maximum number of audio patches in a batch + if return_attention_mask: + audio_mask = [ + (ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len) * [1] + + (max_patch_len - ceil(feature.shape[0] / self.patch_size[0]) * self.freq_len) * [0] + for feature in audio_features + ] + audio_mask = np.array(audio_mask).astype(np.float32) + + # convert into correct format for padding + max_time_len = max_patch_len // self.freq_len * self.patch_size[0] # The maximum audio size in a batch + padded_audio_features = np.ones([len(audio_features), 1, max_time_len, self.feature_size]).astype(np.float32) + padded_audio_features = padded_audio_features * self.padding_value + for i in range(len(audio_features)): + feature = audio_features[i] + padded_audio_features[i, :, : feature.shape[0], :] = feature + + # return as BatchFeature + if return_attention_mask: + data = {"audio_values": padded_audio_features, "audio_mask": audio_mask} + else: + data = {"audio_values": padded_audio_features} + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + return encoded_inputs + + +__all__ = ["TvltFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py new file mode 100644 index 0000000000000000000000000000000000000000..db87d6e8d5686e9d5f5b04ab9020d0e4bed7aa19 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/image_processing_tvlt.py @@ -0,0 +1,438 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for TVLT.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ....image_transforms import ( + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ....image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ....utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +def make_batched(videos) -> List[List[ImageInput]]: + if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)): + return videos + + elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): + videos_dim = np.array(videos[0]).ndim + if videos_dim == 3: + return [videos] + elif videos_dim == 4: + return videos + + elif is_valid_image(videos): + videos_dim = np.array(videos).ndim + if videos_dim == 3: + return [[videos]] + elif videos_dim == 4: + return [videos] + elif videos_dim == 5: + return videos + + raise ValueError(f"Could not make batched video from {videos}") + + +class TvltImageProcessor(BaseImageProcessor): + r""" + Constructs a TVLT image processor. + + This processor can be used to prepare either videos or images for the model by converting images to 1-frame videos. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the output image after resizing. The shortest edge of the image will be resized to + `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by + `size` in the `preprocess` method. + patch_size (`List[int]` *optional*, defaults to [16,16]): + The patch size of image patch embedding. + num_frames (`int` *optional*, defaults to 8): + The maximum number of video frames. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop` + parameter in the `preprocess` method. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to 1/255): + Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter + in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + """ + + model_input_names = [ + "pixel_values", + "pixel_mask", + "pixel_values_mixed", + "pixel_mask_mixed", + ] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + patch_size: List[int] = [16, 16], + num_frames: int = 8, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = IMAGENET_STANDARD_MEAN, + image_std: Optional[Union[float, List[float]]] = IMAGENET_STANDARD_STD, + init_mask_generator=False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.num_frames = num_frames + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self._valid_processor_keys = [ + "videos", + "do_resize", + "size", + "patch_size", + "num_frames", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "is_mixed", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will + have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its + shortest edge of length `s` while keeping the aspect ratio of the original image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size, default_to_square=False) + if "shortest_edge" in size: + output_size = get_resize_output_image_size( + image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + output_size = (size["height"], size["width"]) + else: + raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + def preprocess( + self, + videos: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + patch_size: List[int] = None, + num_frames: Optional[int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + is_mixed: bool = False, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an videos or image or batch of videos or images. + + Args: + videos (`ImageInput`): + Images or videos to preprocess. Expects a single or batch of frames with pixel values ranging from 0 to + 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after applying resize. + patch_size (`List[int]` *optional*, defaults to self.patch_size): + The patch size of image patch embedding. + num_frames (`int` *optional*, defaults to self.num_frames): + The maximum number of video frames. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_centre_crop`): + Whether to centre crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after applying the centre crop. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + is_mixed (`bool`, *optional*): + If the input video has negative samples. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the inferred channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). + + - **pixel_mask** -- Pixel masks to be fed to a model, of shape (batch_size, num_pixel_patches). + + - **pixel_values_mixed** -- Pixel values with both postive or negative to be fed to a model, of shape + (batch_size, num_channels, height, width). + + - **pixel_mask_mixed** -- Pixel masks with both postive or negative to be fed to a model, of shape + (batch_size, num_pixel_patches). + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + patch_size = patch_size if patch_size is not None else self.patch_size + num_frames = num_frames if patch_size is not None else self.num_frames + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + if not valid_images(videos): + raise ValueError( + "Invalid image or video type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + videos = make_batched(videos) + + # Check number of frames is fewer than maximum frames + for video in videos: + if len(video) > self.num_frames: + raise ValueError( + f"number of frames must not be greater than the maximum frames of the model {self.num_frames}." + ) + + max_num_frames = max([len(video) for video in videos]) + num_patches_per_image = (size["shortest_edge"] // patch_size[0]) ** 2 + video_masks = np.array( + [ + len(video) * num_patches_per_image * [1] + (max_num_frames - len(video)) * num_patches_per_image * [0] + for video in videos + ] + ) + + videos = [ + [ + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in video + ] + for video in videos + ] + + # If videos contain both positive/negative, use mixed key for video-audio matching task + if is_mixed: + data = {"pixel_values_mixed": videos, "pixel_mask_mixed": video_masks} + else: + data = {"pixel_values": videos, "pixel_mask": video_masks} + + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["TvltImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py new file mode 100644 index 0000000000000000000000000000000000000000..279224ac4d24a2c4b3a7694dbd20a77390a0a793 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -0,0 +1,1291 @@ +# coding=utf-8 +# Copyright 2023 MURGe-Lab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch TVLT model.""" + +import collections.abc +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_tvlt import TvltConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "TvltConfig" +_CHECKPOINT_FOR_DOC = "ZinengTang/tvlt-base" + + +@dataclass +class TvltModelOutput(ModelOutput): + """ + Class for TvltModel's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`): + Pixel sequence of hidden-states at the output of the last layer of the model. + last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`): + Audio sequence of hidden-states at the output of the last layer of the model. + pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`): + Tensor indicating which pixel patches are masked (1) and which are not (0). + audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`): + Tensor indicating which audio patches are masked (1) and which are not (0). + pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`): + Tensor containing the ids permutation of pixel masking. + audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`): + Tensor containing the ids permutation of audio masking. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + last_pixel_hidden_state: Optional[torch.FloatTensor] = None + last_audio_hidden_state: Optional[torch.FloatTensor] = None + pixel_label_masks: Optional[torch.LongTensor] = None + audio_label_masks: Optional[torch.LongTensor] = None + pixel_ids_restore: Optional[torch.LongTensor] = None + audio_ids_restore: Optional[torch.LongTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TvltDecoderOutput(ModelOutput): + """ + Class for TvltDecoder's outputs, with potential hidden states and attentions. + + Args: + logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`): + Pixel reconstruction logits. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class TvltForPreTrainingOutput(ModelOutput): + """ + Class for TvltForPreTraining's outputs, with potential hidden states and attentions. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`): + Pixel reconstruction loss. + matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`): + Matching objective logits. + pixel_logits (`torch.FloatTensor` of shape + `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction + logits. + audio_logits (`torch.FloatTensor` of shape + `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction + logits. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + matching_logits: Optional[torch.FloatTensor] = None + pixel_logits: Optional[torch.FloatTensor] = None + audio_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +def generate_pixel_mask_noise(pixel_values, pixel_mask=None, mask_ratio=0.75): + """Generate noise for audio masking.""" + + batch_size, seq_len = pixel_values.shape[:2] + noise = torch.rand((batch_size, seq_len), device=pixel_values.device) # noise in [0, 1] + len_keep = int(seq_len * (1 - mask_ratio)) + return noise, len_keep + + +def generate_audio_mask_noise(audio_values, audio_mask=None, mask_ratio=0.75, mask_type="patch-level", freq_len=8): + """Generate noise for audio masking.""" + + batch_size, seq_len = audio_values.shape[:2] + if mask_type == "frame-level": + num_time_patches = seq_len // freq_len + noise = ( + torch.rand(batch_size, num_time_patches, device=audio_values.device) + .unsqueeze(-1) + .repeat(1, 1, freq_len) + .view(batch_size, seq_len) + ) # noise in [0, 1] + elif mask_type == "patch-level": + noise = torch.rand(batch_size, seq_len, device=audio_values.device) # noise in [0, 1] + len_keep = int(seq_len * (1 - mask_ratio)) + return noise, len_keep + + +def random_masking(sequence, noise, len_keep, attention_masks=None): + """ + Perform random masking by per-sample shuffling on frame-level. Per-sample shuffling is done by argsort random + noise. sequence: [batch_size, seq_len, hidden_dim], sequence + """ + + batch_size, seq_len, hidden_dim = sequence.shape + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + sequence_masked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, hidden_dim)) + + # generate the binary mask: 0 is keep, 1 is remove + label_masks = torch.ones([batch_size, seq_len], device=sequence.device) + label_masks[:, :len_keep] = 0 + # unshuffle to get the binary mask + label_masks = torch.gather(label_masks, dim=1, index=ids_restore) + + if attention_masks is not None: + label_masks *= attention_masks + attention_masks = torch.gather(attention_masks, dim=1, index=ids_keep) + + return sequence_masked, attention_masks, label_masks, ids_restore + + +class TvltPixelEmbeddings(nn.Module): + """Construct the patch and position embeddings.""" + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = TvltPixelPatchEmbeddings(config) + self.num_patches_per_image = self.patch_embeddings.num_patches_per_image + + self.type_embed_v = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size)) + self.pos_embed_v = nn.Parameter(torch.zeros(1, self.num_patches_per_image, config.hidden_size)) + + self.config = config + + def forward(self, pixel_values, attention_masks=None): + # create patch embeddings + batch_size, num_frames, num_channels, height, width = pixel_values.shape + + embeddings = self.patch_embeddings(pixel_values) + embeddings += self.pos_embed_v.repeat(1, num_frames, 1) + embeddings += torch.repeat_interleave(self.temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1) + embeddings += self.type_embed_v + + return embeddings, attention_masks + + +class TvltAudioEmbeddings(nn.Module): + """Construct the patch and position embeddings.""" + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = TvltAudioPatchEmbeddings(config) + self.num_patches = self.patch_embeddings.num_patches + + self.type_embed_a = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.num_freq_patches = config.frequency_length // config.audio_patch_size[1] + self.pos_embed_a = nn.Parameter(torch.zeros(1, self.num_patches // self.num_freq_patches, config.hidden_size)) + self.freq_embed = nn.Parameter(torch.zeros(1, self.num_freq_patches, config.hidden_size)) + + self.num_freq_patches = config.frequency_length // config.audio_patch_size[1] + self.config = config + + def forward(self, audio_values, attention_masks=None): + # create patch embeddings + embeddings = self.patch_embeddings(audio_values) + + num_time_patches = embeddings.size(1) // self.num_freq_patches + embeddings += self.freq_embed.repeat(1, num_time_patches, 1) + embeddings += torch.repeat_interleave(self.pos_embed_a[:, :num_time_patches], self.num_freq_patches, dim=1) + embeddings += self.type_embed_a + + return embeddings, attention_masks + + +class TvltPixelPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.image_patch_size + num_channels, hidden_size = config.num_image_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches_per_image = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches_per_image = num_patches_per_image + self.hidden_size = hidden_size + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size, num_frames, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + + pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + embeddings = embeddings.reshape(batch_size, num_frames * self.num_patches_per_image, self.hidden_size) + + return embeddings + + +class TvltAudioPatchEmbeddings(nn.Module): + """ + This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + spectrogram_length, frequency_length, patch_size = ( + config.spectrogram_length, + config.frequency_length, + config.audio_patch_size, + ) + num_channels, hidden_size = config.num_audio_channels, config.hidden_size + + spectrogram_size = (spectrogram_length, frequency_length) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (spectrogram_size[1] // patch_size[1]) * (spectrogram_size[0] // patch_size[0]) + patch_shape = (spectrogram_size[0] // patch_size[0], spectrogram_size[1] // patch_size[1]) + self.spectrogram_size = spectrogram_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.patch_shape = patch_shape + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, audio_values: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = audio_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if height > self.spectrogram_size[0] or width != self.spectrogram_size[1]: + raise ValueError( + f"Input audio size ({height}*{width}) doesn't match model" + f" ({self.spectrogram_size[0]}*{self.spectrogram_size[1]})." + ) + embeddings = self.projection(audio_values).flatten(2).transpose(1, 2) + + return embeddings + + +class TvltSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class TvltSelfOutput(nn.Module): + """ + The residual connection is defined in TvltLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: TvltConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class TvltAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = TvltSelfAttention(config) + self.output = TvltSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class TvltIntermediate(nn.Module): + def __init__(self, config: TvltConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class TvltOutput(nn.Module): + def __init__(self, config: TvltConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class TvltLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = TvltAttention(config) + self.intermediate = TvltIntermediate(config) + self.output = TvltOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViLT, layernorm is applied before self-attention + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states.to(attention_output.device) + + # in ViLT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class TvltEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([TvltLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class TvltPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = TvltConfig + base_model_prefix = "tvlt" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +TVLT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`TvltConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TVLT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for + details. + + audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for + details. + + pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`): + Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for + details. + + audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`): + Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for + details. + + pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can + be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details. + + pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See + [`TvltProcessor.__call__`] for details. + + mask_pixel (`bool`, *optional*): + Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining. + + mask_audio (`bool`, *optional*): + Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.", + TVLT_START_DOCSTRING, +) +class TvltModel(TvltPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.pixel_embeddings = TvltPixelEmbeddings(config) + self.audio_embeddings = TvltAudioEmbeddings(config) + self.encoder = TvltEncoder(config) + + self.cls_embedding = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + + if config.use_mean_pooling: + self.layernorm = None + else: + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.pixel_embeddings.patch_embeddings, self.audio_embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + audio_values: torch.FloatTensor, + pixel_mask: Optional[torch.FloatTensor] = None, + audio_mask: Optional[torch.FloatTensor] = None, + mask_pixel: bool = False, + mask_audio: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import TvltProcessor, TvltModel + >>> import numpy as np + >>> import torch + + >>> num_frames = 8 + >>> images = list(np.random.randn(num_frames, 3, 224, 224)) + >>> audio = list(np.random.randn(10000)) + + >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base") + >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base") + + >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt") + + >>> outputs = model(**input_dict) + >>> loss = outputs.loss + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + pixel_embedding_output, pixel_mask = self.pixel_embeddings(pixel_values, pixel_mask) + + audio_embedding_output, audio_mask = self.audio_embeddings(audio_values, audio_mask) + + # Mask pixel if mask_pixel is True + pixel_label_masks = None + pixel_ids_restore = None + if mask_pixel: + pixel_mask_noise, pixel_len_keep = generate_pixel_mask_noise( + pixel_embedding_output, pixel_mask=pixel_mask, mask_ratio=self.config.pixel_mask_ratio + ) + pixel_embedding_output, pixel_mask, pixel_label_masks, pixel_ids_restore = random_masking( + pixel_embedding_output, + pixel_mask_noise, + pixel_len_keep, + attention_masks=pixel_mask, + ) + + # Mask audio if mask_audio is True + audio_label_masks = None + audio_ids_restore = None + if mask_audio: + num_freq_patches = self.config.frequency_length // self.config.audio_patch_size[1] + audio_mask_noise, audio_len_keep = generate_audio_mask_noise( + audio_embedding_output, + audio_mask=audio_mask, + mask_ratio=self.config.audio_mask_ratio, + mask_type=self.config.audio_mask_type, + freq_len=num_freq_patches, + ) + audio_embedding_output, audio_mask, audio_label_masks, audio_ids_restore = random_masking( + audio_embedding_output, + audio_mask_noise, + audio_len_keep, + attention_masks=audio_mask, + ) + + # Prepare for encoder inputs and attention masks + batch_size = pixel_values.size(0) + embedding_output = torch.cat( + [self.cls_embedding.repeat(batch_size, 1, 1), pixel_embedding_output, audio_embedding_output], 1 + ) + masked_pixel_len = pixel_embedding_output.size(1) + + attention_mask = None + if pixel_mask is not None and audio_mask is not None: + attention_mask = torch.cat([pixel_mask[:, :1], pixel_mask, audio_mask], 1) + + input_shape = embedding_output.size() + extended_attention_mask = None + if attention_mask is not None: + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + if self.layernorm is not None: + sequence_output = self.layernorm(sequence_output) + + pixel_sequence_output = sequence_output[:, 1 : 1 + masked_pixel_len] + audio_sequence_output = sequence_output[:, 1 + masked_pixel_len :] + if not return_dict: + return ( + sequence_output, + pixel_sequence_output, + audio_sequence_output, + pixel_label_masks, + audio_label_masks, + pixel_ids_restore, + audio_ids_restore, + ) + encoder_outputs[1:] + + return TvltModelOutput( + last_hidden_state=sequence_output, + last_pixel_hidden_state=pixel_sequence_output, + last_audio_hidden_state=audio_sequence_output, + pixel_label_masks=pixel_label_masks, + audio_label_masks=audio_label_masks, + pixel_ids_restore=pixel_ids_restore, + audio_ids_restore=audio_ids_restore, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class TvltDecoder(nn.Module): + def __init__(self, config): + super().__init__() + + decoder_config = deepcopy(config) + decoder_config.hidden_size = config.decoder_hidden_size + decoder_config.num_hidden_layers = config.decoder_num_hidden_layers + decoder_config.num_attention_heads = config.decoder_num_attention_heads + decoder_config.intermediate_size = config.decoder_intermediate_size + self.decoder_layers = nn.ModuleList( + [TvltLayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)] + ) + + self.layernorm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps) + + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + # apply Transformer layers (blocks) + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + for i, layer_module in enumerate(self.decoder_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + None, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # predictor projection + logits = self.layernorm(hidden_states) + + if not return_dict: + return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None) + return TvltDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions) + + +@add_start_docstrings( + "The TVLT Model transformer with the decoder on top for self-supervised pre-training.", + TVLT_START_DOCSTRING, +) +class TvltForPreTraining(TvltPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.task_matching = config.task_matching + self.task_mae = config.task_mae + if not (self.task_matching or self.task_mae): + raise ValueError("Must set at least one of matching task and MAE task to true") + + self.tvlt = TvltModel(config) + + if self.task_matching: + self.matching_head = TvltMatchingHead(config) + + if self.task_mae: + self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True) + + self.pixel_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size)) + self.audio_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size)) + + self.decoder = TvltDecoder(config) + + decoder_hidden_size = config.decoder_hidden_size + + num_frames = config.num_frames + num_patches_per_image = self.tvlt.pixel_embeddings.num_patches_per_image + self.decoder_pixel_pos_embed = nn.Parameter(torch.zeros(1, num_patches_per_image, decoder_hidden_size)) + self.decoder_temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, decoder_hidden_size)) + self.decoder_pixel_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) + + num_audio_patches = self.tvlt.audio_embeddings.num_patches + num_freq_patches = config.frequency_length // config.audio_patch_size[1] + self.decoder_audio_pos_embed = nn.Parameter( + torch.zeros(1, num_audio_patches // num_freq_patches, decoder_hidden_size) + ) + self.decoder_freq_embed = nn.Parameter(torch.zeros(1, num_freq_patches, decoder_hidden_size)) + self.decoder_audio_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) + + pixel_mae_output_dim = self.config.image_patch_size[0] ** 2 * self.config.num_image_channels + self.pixel_mae_head = TvltMAEHead(config, pixel_mae_output_dim) + audio_mae_output_dim = ( + self.config.audio_patch_size[0] * self.config.audio_patch_size[1] * self.config.num_audio_channels + ) + self.audio_mae_head = TvltMAEHead(config, audio_mae_output_dim) + + self.num_frames = num_frames + self.num_patches_per_image = num_patches_per_image + self.num_freq_patches = num_freq_patches + self.image_patch_size = config.image_patch_size + self.audio_patch_size = config.audio_patch_size + + # Initialize weights and apply final processing + self.post_init() + + def patchify_pixel(self, pixel_values): + """ + pixel_values: [batch_size, num_frames, 3, height, width] + """ + batch_size, num_frames, num_channels, height, width = pixel_values.shape + num_patches_height = pixel_values.shape[3] // self.image_patch_size[0] + num_patches_width = pixel_values.shape[4] // self.image_patch_size[1] + patchified_pixel_values = pixel_values.reshape( + shape=( + batch_size, + num_frames, + num_channels, + num_patches_height, + self.image_patch_size[0], + num_patches_width, + self.image_patch_size[1], + ) + ) + patchified_pixel_values = torch.einsum("ntchpwq->nthwpqc", patchified_pixel_values) + patchified_pixel_values = patchified_pixel_values.reshape( + shape=( + batch_size, + num_patches_height * num_patches_width * num_frames, + self.image_patch_size[0] * self.image_patch_size[1] * num_channels, + ) + ) + return patchified_pixel_values + + def patchify_audio(self, audio_values): + """ + audio_values: [batch_size, 1, height, width] + """ + batch_size, num_channels, height, width = audio_values.shape + num_patches_height = height // self.audio_patch_size[0] + num_patches_width = width // self.audio_patch_size[1] + patchified_audio_values = audio_values.reshape( + shape=( + batch_size, + num_channels, + num_patches_height, + self.audio_patch_size[0], + num_patches_width, + self.audio_patch_size[1], + ) + ) + patchified_audio_values = torch.einsum("nchpwq->nhwpqc", patchified_audio_values) + patchified_audio_values = patchified_audio_values.reshape( + shape=( + batch_size, + num_patches_height * num_patches_width, + self.audio_patch_size[0] * self.audio_patch_size[1] * num_channels, + ) + ) + return patchified_audio_values + + def pixel_mae_loss(self, pixel_values, pixel_predictions, mask): + patchified_pixel_values = self.patchify_pixel(pixel_values) + loss = (pixel_predictions - patchified_pixel_values) ** 2 + loss = loss.mean(dim=-1) # [batch_size, pixel_pixel_length], mean loss per patch + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def audio_mae_loss(self, audio_values, audio_predictions, mask): + patchified_audio_values = self.patchify_audio(audio_values) + loss = (audio_predictions - patchified_audio_values) ** 2 + loss = loss.mean(dim=-1) # [batch_size, audio_pixel_length], mean loss per patch + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def concatenate_mask(self, mask_token, sequence, ids_restore): + batch_size, seq_length, dim = sequence.shape + mask_tokens = mask_token.repeat(batch_size, ids_restore.shape[1] - seq_length, 1) + padded_sequence = torch.cat([sequence, mask_tokens], dim=1) + padded_sequence = torch.gather( + padded_sequence, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim) + ) # unshuffle + return padded_sequence + + @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + audio_values: torch.FloatTensor, + pixel_mask: Optional[torch.FloatTensor] = None, + audio_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values_mixed: Optional[torch.FloatTensor] = None, + pixel_mask_mixed: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]: + r""" + pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be + obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details. + + pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel masks of pixel_values_mixed. Pixel values mixed can be obtained using [`TvltProcessor`]. See + [`TvltProcessor.__call__`] for details. + + labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*): + Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1. + + Return: + + Examples: + + ```python + >>> from transformers import TvltProcessor, TvltForPreTraining + >>> import numpy as np + >>> import torch + + >>> num_frames = 8 + >>> images = list(np.random.randn(num_frames, 3, 224, 224)) + >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224)) + >>> audio = list(np.random.randn(10000)) + >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base") + >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base") + >>> input_dict = processor( + ... images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt" + ... ) + + >>> outputs = model(**input_dict) + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + total_loss = 0.0 + + if self.task_matching: + if labels is None: + raise ValueError("Matching task requires labels") + if pixel_values_mixed is None: + raise ValueError("Matching task requires pixel_values_mixed") + + outputs = self.tvlt( + pixel_values_mixed, + audio_values, + pixel_mask=pixel_mask_mixed, + audio_mask=audio_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + matching_logits = self.matching_head(sequence_output) + + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(matching_logits.view(-1), labels.view(-1)) + total_loss += loss + + pixel_logits = None + audio_logits = None + if self.task_mae and self.training: + outputs = self.tvlt( + pixel_values, + audio_values, + pixel_mask=pixel_mask, + audio_mask=audio_mask, + mask_pixel=True, + mask_audio=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pixel_sequence_output = outputs.last_pixel_hidden_state if return_dict else outputs[1] + audio_sequence_output = outputs.last_audio_hidden_state if return_dict else outputs[2] + pixel_label_masks = outputs.pixel_label_masks if return_dict else outputs[3] + audio_label_masks = outputs.audio_label_masks if return_dict else outputs[4] + pixel_ids_restore = outputs.pixel_ids_restore if return_dict else outputs[5] + audio_ids_restore = outputs.audio_ids_restore if return_dict else outputs[6] + + pixel_decoder_input = self.encoder_to_decoder( + pixel_sequence_output + ) # [batch_size, num_masked_pixel_patches, decoder_hidden_size] + audio_decoder_input = self.encoder_to_decoder( + audio_sequence_output + ) # [batch_size, num_masked_audio_patches, decoder_hidden_size] + num_frames = pixel_values.size(1) + pixel_decoder_input = self.concatenate_mask(self.pixel_mask_token, pixel_decoder_input, pixel_ids_restore) + pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_pos_embed.repeat(1, num_frames, 1) + pixel_decoder_input = pixel_decoder_input + torch.repeat_interleave( + self.decoder_temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1 + ) + pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_type_embed + pixel_decoder_outputs = self.decoder(pixel_decoder_input) + pixel_logits = self.pixel_mae_head(pixel_decoder_outputs.logits) + + audio_decoder_input = self.concatenate_mask(self.audio_mask_token, audio_decoder_input, audio_ids_restore) + num_time_patches = audio_decoder_input.size(1) // self.num_freq_patches + audio_decoder_input = audio_decoder_input + self.decoder_freq_embed.repeat(1, num_time_patches, 1) + audio_decoder_input = audio_decoder_input + torch.repeat_interleave( + self.decoder_audio_pos_embed[:, :num_time_patches], self.num_freq_patches, dim=1 + ) + audio_decoder_input = audio_decoder_input + self.decoder_audio_type_embed + audio_decoder_outputs = self.decoder(audio_decoder_input) + audio_logits = self.audio_mae_head(audio_decoder_outputs.logits) + + loss = self.pixel_mae_loss(pixel_values, pixel_logits, pixel_label_masks) + self.audio_mae_loss( + audio_values, audio_logits, audio_label_masks + ) + total_loss += loss + + if not return_dict: + output = (matching_logits, pixel_logits, audio_logits) + outputs[7:] + return ((total_loss,) + output) if loss is not None else output + + return TvltForPreTrainingOutput( + loss=total_loss, + matching_logits=matching_logits, + pixel_logits=pixel_logits, + audio_logits=audio_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class TvltPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class TvltMatchingHead(nn.Module): + def __init__(self, config): + super().__init__() + self.pooler = TvltPooler(config) + self.fc = nn.Linear(config.hidden_size, 1) + + def forward(self, hidden_states): + hidden_states = self.fc(self.pooler(hidden_states)) + return hidden_states + + +class TvltMAEHead(nn.Module): + def __init__(self, config, output_dim=None): + super().__init__() + self.config = config + self.decoder = nn.Linear(config.decoder_hidden_size, output_dim) + + def forward(self, hidden_states): + hidden_states = self.decoder(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token) + for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval. + """, + TVLT_START_DOCSTRING, +) +class TvltForAudioVisualClassification(TvltPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.tvlt = TvltModel(config) + + # Classifier head + self.classifier = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size * 2), + nn.LayerNorm(config.hidden_size * 2, eps=config.layer_norm_eps), + nn.GELU(), + nn.Linear(config.hidden_size * 2, config.num_labels), + ) + self.config = config + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + audio_values: torch.FloatTensor, + pixel_mask: Optional[torch.FloatTensor] = None, + audio_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*): + Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes + refers to the number of classes in audiovisual tasks. + + Return: + + Examples: + ```python + >>> from transformers import TvltProcessor, TvltForAudioVisualClassification + >>> import numpy as np + >>> import torch + + >>> num_frames = 8 + >>> images = list(np.random.randn(num_frames, 3, 224, 224)) + >>> audio = list(np.random.randn(10000)) + >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base") + >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base") + >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt") + + >>> outputs = model(**input_dict) + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.tvlt( + pixel_values, + audio_values, + pixel_mask=pixel_mask, + audio_mask=audio_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0][:, 0] + logits = self.classifier(sequence_output) # rank value + + loss = None + if labels is not None: + if self.config.loss_type == "regression": + loss_fct = MSELoss() + loss = loss_fct(logits, labels) + elif self.config.loss_type == "classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[4:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["TvltModel", "TvltForPreTraining", "TvltForAudioVisualClassification", "TvltPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9f8e0978d8a2c4158094ebdf3bc2279700e7ba5 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/tvlt/processing_tvlt.py @@ -0,0 +1,92 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for TVLT. +""" + +from ....processing_utils import ProcessorMixin + + +class TvltProcessor(ProcessorMixin): + r""" + Constructs a TVLT processor which wraps a TVLT image processor and TVLT feature extractor into a single processor. + + [`TvltProcessor`] offers all the functionalities of [`TvltImageProcessor`] and [`TvltFeatureExtractor`]. See the + docstring of [`~TvltProcessor.__call__`] for more information. + + Args: + image_processor (`TvltImageProcessor`): + An instance of [`TvltImageProcessor`]. The image processor is a required input. + feature_extractor (`TvltFeatureExtractor`): + An instance of [`TvltFeatureExtractor`]. The feature extractor is a required input. + """ + + attributes = ["image_processor", "feature_extractor"] + image_processor_class = "TvltImageProcessor" + feature_extractor_class = "TvltFeatureExtractor" + + def __init__(self, image_processor, feature_extractor): + super().__init__(image_processor=image_processor, feature_extractor=feature_extractor) + + self.image_processor = image_processor + self.feature_extractor = feature_extractor + + def __call__( + self, + images=None, + audio=None, + images_mixed=None, + sampling_rate=None, + mask_audio=False, + mask_pixel=False, + *args, + **kwargs, + ): + """ + Forwards the `images` argument to TvltImageProcessor's [`~TvltImageProcessor.preprocess`] and the `audio` + argument to TvltFeatureExtractor's [`~TvltFeatureExtractor.__call__`]. Please refer to the docstring of the + above two methods for more information. + """ + + if images is None and audio is None: + raise ValueError("You need to specify either an `images` or `audio` input to process.") + + images_mixed_dict = None + if images is not None: + images_dict = self.image_processor(images, mask_pixel=mask_pixel, *args, **kwargs) + if images_mixed is not None: + images_mixed_dict = self.image_processor(images_mixed, is_mixed=True, *args, **kwargs) + if audio is not None: + audio_dict = self.feature_extractor( + audio, *args, sampling_rate=sampling_rate, mask_audio=mask_audio, **kwargs + ) + + output_dict = {} + if audio is not None: + output_dict.update(audio_dict) + if images is not None: + output_dict.update(images_dict) + if images_mixed_dict is not None: + output_dict.update(images_mixed_dict) + return output_dict + + @property + def model_input_names(self): + image_processor_input_names = self.image_processor.model_input_names + feature_extractor_input_names = self.feature_extractor.model_input_names + return list(dict.fromkeys(image_processor_input_names + feature_extractor_input_names)) + + +__all__ = ["TvltProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9552c827365d3913539be3eafaf822146ada5829 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/van/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_van import * + from .modeling_van import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py b/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py new file mode 100644 index 0000000000000000000000000000000000000000..08f1db7a4b48e7fb32b986cdc557b7aa3d328ed0 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/van/configuration_van.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VAN model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class VanConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VanModel`]. It is used to instantiate a VAN model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the VAN + [Visual-Attention-Network/van-base](https://huggingface.co/Visual-Attention-Network/van-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + patch_sizes (`List[int]`, *optional*, defaults to `[7, 3, 3, 3]`): + Patch size to use in each stage's embedding layer. + strides (`List[int]`, *optional*, defaults to `[4, 2, 2, 2]`): + Stride size to use in each stage's embedding layer to downsample the input. + hidden_sizes (`List[int]`, *optional*, defaults to `[64, 128, 320, 512]`): + Dimensionality (hidden size) at each stage. + depths (`List[int]`, *optional*, defaults to `[3, 3, 12, 3]`): + Depth (number of layers) for each stage. + mlp_ratios (`List[int]`, *optional*, defaults to `[8, 8, 4, 4]`): + The expansion ratio for mlp layer at each stage. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in each layer. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + layer_scale_init_value (`float`, *optional*, defaults to 0.01): + The initial value for layer scaling. + drop_path_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for stochastic depth. + dropout_rate (`float`, *optional*, defaults to 0.0): + The dropout probability for dropout. + + Example: + ```python + >>> from transformers import VanModel, VanConfig + + >>> # Initializing a VAN van-base style configuration + >>> configuration = VanConfig() + >>> # Initializing a model from the van-base style configuration + >>> model = VanModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "van" + + def __init__( + self, + image_size=224, + num_channels=3, + patch_sizes=[7, 3, 3, 3], + strides=[4, 2, 2, 2], + hidden_sizes=[64, 128, 320, 512], + depths=[3, 3, 12, 3], + mlp_ratios=[8, 8, 4, 4], + hidden_act="gelu", + initializer_range=0.02, + layer_norm_eps=1e-6, + layer_scale_init_value=1e-2, + drop_path_rate=0.0, + dropout_rate=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.image_size = image_size + self.num_channels = num_channels + self.patch_sizes = patch_sizes + self.strides = strides + self.hidden_sizes = hidden_sizes + self.depths = depths + self.mlp_ratios = mlp_ratios + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.layer_scale_init_value = layer_scale_init_value + self.drop_path_rate = drop_path_rate + self.dropout_rate = dropout_rate + + +__all__ = ["VanConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cd87217f051a04a6a339e93e6d77cfc85c701832 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/van/convert_van_to_pytorch.py @@ -0,0 +1,290 @@ +# coding=utf-8 +# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert VAN checkpoints from the original repository. + +URL: https://github.com/Visual-Attention-Network/VAN-Classification""" + +import argparse +import json +import sys +from dataclasses import dataclass, field +from functools import partial +from pathlib import Path +from typing import List, Optional + +import torch +import torch.nn as nn +from huggingface_hub import cached_download, hf_hub_download +from torch import Tensor + +from transformers import AutoImageProcessor, VanConfig, VanForImageClassification +from transformers.models.deprecated.van.modeling_van import VanLayerScaling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +@dataclass +class Tracker: + module: nn.Module + traced: List[nn.Module] = field(default_factory=list) + handles: list = field(default_factory=list) + + def _forward_hook(self, m, inputs: Tensor, outputs: Tensor): + has_not_submodules = len(list(m.modules())) == 1 or isinstance(m, nn.Conv2d) or isinstance(m, nn.BatchNorm2d) + if has_not_submodules: + if not isinstance(m, VanLayerScaling): + self.traced.append(m) + + def __call__(self, x: Tensor): + for m in self.module.modules(): + self.handles.append(m.register_forward_hook(self._forward_hook)) + self.module(x) + [x.remove() for x in self.handles] + return self + + @property + def parametrized(self): + # check the len of the state_dict keys to see if we have learnable params + return list(filter(lambda x: len(list(x.state_dict().keys())) > 0, self.traced)) + + +@dataclass +class ModuleTransfer: + src: nn.Module + dest: nn.Module + verbose: int = 0 + src_skip: List = field(default_factory=list) + dest_skip: List = field(default_factory=list) + + def __call__(self, x: Tensor): + """ + Transfer the weights of `self.src` to `self.dest` by performing a forward pass using `x` as input. Under the + hood we tracked all the operations in both modules. + """ + dest_traced = Tracker(self.dest)(x).parametrized + src_traced = Tracker(self.src)(x).parametrized + + src_traced = list(filter(lambda x: type(x) not in self.src_skip, src_traced)) + dest_traced = list(filter(lambda x: type(x) not in self.dest_skip, dest_traced)) + + if len(dest_traced) != len(src_traced): + raise Exception( + f"Numbers of operations are different. Source module has {len(src_traced)} operations while" + f" destination module has {len(dest_traced)}." + ) + + for dest_m, src_m in zip(dest_traced, src_traced): + dest_m.load_state_dict(src_m.state_dict()) + if self.verbose == 1: + print(f"Transfered from={src_m} to={dest_m}") + + +def copy_parameters(from_model: nn.Module, our_model: nn.Module) -> nn.Module: + # nn.Parameter cannot be tracked by the Tracker, thus we need to manually convert them + from_state_dict = from_model.state_dict() + our_state_dict = our_model.state_dict() + config = our_model.config + all_keys = [] + for stage_idx in range(len(config.hidden_sizes)): + for block_id in range(config.depths[stage_idx]): + from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_1" + to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.attention_scaling.weight" + + all_keys.append((from_key, to_key)) + from_key = f"block{stage_idx + 1}.{block_id}.layer_scale_2" + to_key = f"van.encoder.stages.{stage_idx}.layers.{block_id}.mlp_scaling.weight" + + all_keys.append((from_key, to_key)) + + for from_key, to_key in all_keys: + our_state_dict[to_key] = from_state_dict.pop(from_key) + + our_model.load_state_dict(our_state_dict) + return our_model + + +def convert_weight_and_push( + name: str, + config: VanConfig, + checkpoint: str, + from_model: nn.Module, + save_directory: Path, + push_to_hub: bool = True, +): + print(f"Downloading weights for {name}...") + checkpoint_path = cached_download(checkpoint) + print(f"Converting {name}...") + from_state_dict = torch.load(checkpoint_path, weights_only=True)["state_dict"] + from_model.load_state_dict(from_state_dict) + from_model.eval() + with torch.no_grad(): + our_model = VanForImageClassification(config).eval() + module_transfer = ModuleTransfer(src=from_model, dest=our_model) + x = torch.randn((1, 3, 224, 224)) + module_transfer(x) + our_model = copy_parameters(from_model, our_model) + + if not torch.allclose(from_model(x), our_model(x).logits): + raise ValueError("The model logits don't match the original one.") + + checkpoint_name = name + print(checkpoint_name) + + if push_to_hub: + our_model.push_to_hub( + repo_path_or_name=save_directory / checkpoint_name, + commit_message="Add model", + use_temp_dir=True, + ) + + # we can use the convnext one + image_processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k") + image_processor.push_to_hub( + repo_path_or_name=save_directory / checkpoint_name, + commit_message="Add image processor", + use_temp_dir=True, + ) + + print(f"Pushed {checkpoint_name}") + + +def convert_weights_and_push(save_directory: Path, model_name: Optional[str] = None, push_to_hub: bool = True): + filename = "imagenet-1k-id2label.json" + num_labels = 1000 + + repo_id = "huggingface/label-files" + num_labels = num_labels + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + id2label = id2label + label2id = {v: k for k, v in id2label.items()} + + ImageNetPreTrainedConfig = partial(VanConfig, num_labels=num_labels, id2label=id2label, label2id=label2id) + + names_to_config = { + "van-tiny": ImageNetPreTrainedConfig( + hidden_sizes=[32, 64, 160, 256], + depths=[3, 3, 5, 2], + mlp_ratios=[8, 8, 4, 4], + ), + "van-small": ImageNetPreTrainedConfig( + hidden_sizes=[64, 128, 320, 512], + depths=[2, 2, 4, 2], + mlp_ratios=[8, 8, 4, 4], + ), + "van-base": ImageNetPreTrainedConfig( + hidden_sizes=[64, 128, 320, 512], + depths=[3, 3, 12, 3], + mlp_ratios=[8, 8, 4, 4], + ), + "van-large": ImageNetPreTrainedConfig( + hidden_sizes=[64, 128, 320, 512], + depths=[3, 5, 27, 3], + mlp_ratios=[8, 8, 4, 4], + ), + } + + names_to_original_models = { + "van-tiny": van_tiny, + "van-small": van_small, + "van-base": van_base, + "van-large": van_large, + } + + names_to_original_checkpoints = { + "van-tiny": ( + "https://huggingface.co/Visual-Attention-Network/VAN-Tiny-original/resolve/main/van_tiny_754.pth.tar" + ), + "van-small": ( + "https://huggingface.co/Visual-Attention-Network/VAN-Small-original/resolve/main/van_small_811.pth.tar" + ), + "van-base": ( + "https://huggingface.co/Visual-Attention-Network/VAN-Base-original/resolve/main/van_base_828.pth.tar" + ), + "van-large": ( + "https://huggingface.co/Visual-Attention-Network/VAN-Large-original/resolve/main/van_large_839.pth.tar" + ), + } + + if model_name: + convert_weight_and_push( + model_name, + names_to_config[model_name], + checkpoint=names_to_original_checkpoints[model_name], + from_model=names_to_original_models[model_name](), + save_directory=save_directory, + push_to_hub=push_to_hub, + ) + else: + for model_name, config in names_to_config.items(): + convert_weight_and_push( + model_name, + config, + checkpoint=names_to_original_checkpoints[model_name], + from_model=names_to_original_models[model_name](), + save_directory=save_directory, + push_to_hub=push_to_hub, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model-name", + default=None, + type=str, + help=( + "The name of the model you wish to convert, it must be one of the supported resnet* architecture," + " currently: van-tiny/small/base/large. If `None`, all of them will the converted." + ), + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=Path, + required=True, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--van_dir", + required=True, + type=Path, + help=( + "A path to VAN's original implementation directory. You can download from here:" + " https://github.com/Visual-Attention-Network/VAN-Classification" + ), + ) + parser.add_argument( + "--push_to_hub", + default=True, + type=bool, + required=False, + help="If True, push model and image processor to the hub.", + ) + + args = parser.parse_args() + pytorch_dump_folder_path: Path = args.pytorch_dump_folder_path + pytorch_dump_folder_path.mkdir(exist_ok=True, parents=True) + van_dir = args.van_dir + # append the path to the parents to maskformer dir + sys.path.append(str(van_dir.parent)) + from van.models.van import van_base, van_large, van_small, van_tiny + + convert_weights_and_push(pytorch_dump_folder_path, args.model_name, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py b/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py new file mode 100644 index 0000000000000000000000000000000000000000..1da03cb544d467d9fdbe9b5258fabaccdedd7eff --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/van/modeling_van.py @@ -0,0 +1,541 @@ +# coding=utf-8 +# Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Visual Attention Network (VAN) model.""" + +import math +from collections import OrderedDict +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ....modeling_utils import PreTrainedModel +from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_van import VanConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "VanConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "Visual-Attention-Network/van-base" +_EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "Visual-Attention-Network/van-base" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class VanDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class VanOverlappingPatchEmbedder(nn.Module): + """ + Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by + half of the area. From [PVTv2: Improved Baselines with Pyramid Vision + Transformer](https://arxiv.org/abs/2106.13797). + """ + + def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2 + ) + self.normalization = nn.BatchNorm2d(hidden_size) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + hidden_state = self.convolution(input) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class VanMlpLayer(nn.Module): + """ + MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision + Transformer](https://arxiv.org/abs/2106.13797). + """ + + def __init__( + self, + in_channels: int, + hidden_size: int, + out_channels: int, + hidden_act: str = "gelu", + dropout_rate: float = 0.5, + ): + super().__init__() + self.in_dense = nn.Conv2d(in_channels, hidden_size, kernel_size=1) + self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size) + self.activation = ACT2FN[hidden_act] + self.dropout1 = nn.Dropout(dropout_rate) + self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1) + self.dropout2 = nn.Dropout(dropout_rate) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.in_dense(hidden_state) + hidden_state = self.depth_wise(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.dropout1(hidden_state) + hidden_state = self.out_dense(hidden_state) + hidden_state = self.dropout2(hidden_state) + return hidden_state + + +class VanLargeKernelAttention(nn.Module): + """ + Basic Large Kernel Attention (LKA). + """ + + def __init__(self, hidden_size: int): + super().__init__() + self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size) + self.depth_wise_dilated = nn.Conv2d( + hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size + ) + self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.depth_wise(hidden_state) + hidden_state = self.depth_wise_dilated(hidden_state) + hidden_state = self.point_wise(hidden_state) + return hidden_state + + +class VanLargeKernelAttentionLayer(nn.Module): + """ + Computes attention using Large Kernel Attention (LKA) and attends the input. + """ + + def __init__(self, hidden_size: int): + super().__init__() + self.attention = VanLargeKernelAttention(hidden_size) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + attention = self.attention(hidden_state) + attended = hidden_state * attention + return attended + + +class VanSpatialAttentionLayer(nn.Module): + """ + Van spatial attention layer composed by projection (via conv) -> act -> Large Kernel Attention (LKA) attention -> + projection (via conv) + residual connection. + """ + + def __init__(self, hidden_size: int, hidden_act: str = "gelu"): + super().__init__() + self.pre_projection = nn.Sequential( + OrderedDict( + [ + ("conv", nn.Conv2d(hidden_size, hidden_size, kernel_size=1)), + ("act", ACT2FN[hidden_act]), + ] + ) + ) + self.attention_layer = VanLargeKernelAttentionLayer(hidden_size) + self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.pre_projection(hidden_state) + hidden_state = self.attention_layer(hidden_state) + hidden_state = self.post_projection(hidden_state) + hidden_state = hidden_state + residual + return hidden_state + + +class VanLayerScaling(nn.Module): + """ + Scales the inputs by a learnable parameter initialized by `initial_value`. + """ + + def __init__(self, hidden_size: int, initial_value: float = 1e-2): + super().__init__() + self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + # unsqueezing for broadcasting + hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state + return hidden_state + + +class VanLayer(nn.Module): + """ + Van layer composed by normalization layers, large kernel attention (LKA) and a multi layer perceptron (MLP). + """ + + def __init__( + self, + config: VanConfig, + hidden_size: int, + mlp_ratio: int = 4, + drop_path_rate: float = 0.5, + ): + super().__init__() + self.drop_path = VanDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.pre_normomalization = nn.BatchNorm2d(hidden_size) + self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act) + self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value) + self.post_normalization = nn.BatchNorm2d(hidden_size) + self.mlp = VanMlpLayer( + hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate + ) + self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + # attention + hidden_state = self.pre_normomalization(hidden_state) + hidden_state = self.attention(hidden_state) + hidden_state = self.attention_scaling(hidden_state) + hidden_state = self.drop_path(hidden_state) + # residual connection + hidden_state = residual + hidden_state + residual = hidden_state + # mlp + hidden_state = self.post_normalization(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = self.mlp_scaling(hidden_state) + hidden_state = self.drop_path(hidden_state) + # residual connection + hidden_state = residual + hidden_state + return hidden_state + + +class VanStage(nn.Module): + """ + VanStage, consisting of multiple layers. + """ + + def __init__( + self, + config: VanConfig, + in_channels: int, + hidden_size: int, + patch_size: int, + stride: int, + depth: int, + mlp_ratio: int = 4, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.embeddings = VanOverlappingPatchEmbedder(in_channels, hidden_size, patch_size, stride) + self.layers = nn.Sequential( + *[ + VanLayer( + config, + hidden_size, + mlp_ratio=mlp_ratio, + drop_path_rate=drop_path_rate, + ) + for _ in range(depth) + ] + ) + self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.embeddings(hidden_state) + hidden_state = self.layers(hidden_state) + # rearrange b c h w -> b (h w) c + batch_size, hidden_size, height, width = hidden_state.shape + hidden_state = hidden_state.flatten(2).transpose(1, 2) + hidden_state = self.normalization(hidden_state) + # rearrange b (h w) c- > b c h w + hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2) + return hidden_state + + +class VanEncoder(nn.Module): + """ + VanEncoder, consisting of multiple stages. + """ + + def __init__(self, config: VanConfig): + super().__init__() + self.stages = nn.ModuleList([]) + patch_sizes = config.patch_sizes + strides = config.strides + hidden_sizes = config.hidden_sizes + depths = config.depths + mlp_ratios = config.mlp_ratios + drop_path_rates = [ + x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + + for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate( + zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates) + ): + is_first_stage = num_stage == 0 + in_channels = hidden_sizes[num_stage - 1] + if is_first_stage: + in_channels = config.num_channels + self.stages.append( + VanStage( + config, + in_channels, + hidden_size, + patch_size=patch_size, + stride=stride, + depth=depth, + mlp_ratio=mlp_expantion, + drop_path_rate=drop_path_rate, + ) + ) + + def forward( + self, + hidden_state: torch.Tensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + all_hidden_states = () if output_hidden_states else None + + for _, stage_module in enumerate(self.stages): + hidden_state = stage_module(hidden_state) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states) + + +class VanPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VanConfig + base_model_prefix = "van" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) + if isinstance(module, nn.Linear) and module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.LayerNorm): + nn.init.constant_(module.bias, 0) + nn.init.constant_(module.weight, 1.0) + elif isinstance(module, nn.Conv2d): + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() + + +VAN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`VanConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VAN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ConvNextImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all stages. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding" + " layer.", + VAN_START_DOCSTRING, +) +class VanModel(VanPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + self.encoder = VanEncoder(config) + # final layernorm layer + self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + pixel_values, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + # global average pooling, n c w h -> n c + pooled_output = last_hidden_state.mean(dim=[-2, -1]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + VAN Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for + ImageNet. + """, + VAN_START_DOCSTRING, +) +class VanForImageClassification(VanPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.van = VanModel(config) + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.van(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states) + + +__all__ = ["VanForImageClassification", "VanModel", "VanPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bd93aa4dabea9b2adbc54eeeb3664de589f43a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vit_hybrid import * + from .image_processing_vit_hybrid import * + from .modeling_vit_hybrid import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..65b6a3e5ef515218afb72e8b871c4ffb2465666f --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/configuration_vit_hybrid.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ViT Hybrid model configuration""" + +from ....configuration_utils import PretrainedConfig +from ....utils import logging +from ...auto.configuration_auto import CONFIG_MAPPING +from ...bit import BitConfig + + +logger = logging.get_logger(__name__) + + +class ViTHybridConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ViTHybridModel`]. It is used to instantiate a ViT + Hybrid model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the ViT Hybrid + [google/vit-hybrid-base-bit-384](https://huggingface.co/google/vit-hybrid-base-bit-384) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone in a dictionary or the config object of the backbone. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 1): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`): + Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + + Example: + + ```python + >>> from transformers import ViTHybridConfig, ViTHybridModel + + >>> # Initializing a ViT Hybrid vit-hybrid-base-bit-384 style configuration + >>> configuration = ViTHybridConfig() + + >>> # Initializing a model (with random weights) from the vit-hybrid-base-bit-384 style configuration + >>> model = ViTHybridModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vit-hybrid" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=1, + num_channels=3, + backbone_featmap_shape=[1, 1024, 24, 24], + qkv_bias=True, + **kwargs, + ): + super().__init__(**kwargs) + if use_pretrained_backbone: + raise ValueError("Pretrained backbones are not supported yet.") + + if backbone_config is not None and backbone is not None: + raise ValueError("You can't specify both `backbone` and `backbone_config`.") + + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with a `BiT` backbone.") + backbone_config = { + "global_padding": "same", + "layer_type": "bottleneck", + "depths": [3, 4, 9], + "out_features": ["stage3"], + "embedding_dynamic_padding": True, + } + + if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: + raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") + + if isinstance(backbone_config, dict): + if "model_type" in backbone_config: + backbone_config_class = CONFIG_MAPPING[backbone_config["model_type"]] + else: + logger.info( + "`model_type` is not found in `backbone_config`. Use `Bit` as the backbone configuration class." + ) + backbone_config_class = BitConfig + backbone_config = backbone_config_class(**backbone_config) + + self.backbone_featmap_shape = backbone_featmap_shape + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + + +__all__ = ["ViTHybridConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..1d717d74c961e509697adab7623d2bc3fe64a1cf --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/convert_vit_hybrid_timm_to_pytorch.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ViT hybrid checkpoints from the timm library.""" + +import argparse +import json +from pathlib import Path + +import requests +import timm +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform + +from transformers import ( + BitConfig, + ViTHybridConfig, + ViTHybridForImageClassification, + ViTHybridImageProcessor, + ViTHybridModel, +) +from transformers.image_utils import PILImageResampling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + + # fmt: off + # stem: + rename_keys.append(("cls_token", "vit.embeddings.cls_token")) + rename_keys.append(("pos_embed", "vit.embeddings.position_embeddings")) + + rename_keys.append(("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias")) + + # backbone + rename_keys.append(("patch_embed.backbone.stem.conv.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.convolution.weight")) + rename_keys.append(("patch_embed.backbone.stem.norm.weight", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.weight")) + rename_keys.append(("patch_embed.backbone.stem.norm.bias", "vit.embeddings.patch_embeddings.backbone.bit.embedder.norm.bias")) + + for stage_idx in range(len(config.backbone_config.depths)): + for layer_idx in range(config.backbone_config.depths[stage_idx]): + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv1.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv1.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm1.bias", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm1.bias")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv2.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv2.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm2.bias", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm2.bias")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.conv3.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.conv3.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.{layer_idx}.norm3.bias", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.{layer_idx}.norm3.bias")) + + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.conv.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.conv.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.weight", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.weight")) + rename_keys.append((f"patch_embed.backbone.stages.{stage_idx}.blocks.0.downsample.norm.bias", f"vit.embeddings.patch_embeddings.backbone.bit.encoder.stages.{stage_idx}.layers.0.downsample.norm.bias")) + + # transformer encoder + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ("pre_logits.fc.weight", "pooler.dense.weight"), + ("pre_logits.fc.bias", "pooler.dense.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + # fmt: on + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our ViT structure. + """ + + # define default ViT hybrid configuration + backbone_config = BitConfig( + global_padding="same", + layer_type="bottleneck", + depths=(3, 4, 9), + out_features=["stage3"], + embedding_dynamic_padding=True, + ) + config = ViTHybridConfig(backbone_config=backbone_config, image_size=384, num_labels=1000) + base_model = False + + # load original model from timm + timm_model = timm.create_model(vit_name, pretrained=True) + timm_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = timm_model.state_dict() + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # load HuggingFace model + if vit_name[-5:] == "in21k": + model = ViTHybridModel(config).eval() + else: + model = ViTHybridForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # create image processor + transform = create_transform(**resolve_data_config({}, model=timm_model)) + timm_transforms = transform.transforms + + pillow_resamplings = { + "bilinear": PILImageResampling.BILINEAR, + "bicubic": PILImageResampling.BICUBIC, + "nearest": PILImageResampling.NEAREST, + } + + processor = ViTHybridImageProcessor( + do_resize=True, + size={"shortest_edge": timm_transforms[0].size}, + resample=pillow_resamplings[timm_transforms[0].interpolation.value], + do_center_crop=True, + crop_size={"height": timm_transforms[1].size[0], "width": timm_transforms[1].size[1]}, + do_normalize=True, + image_mean=timm_transforms[-1].mean.tolist(), + image_std=timm_transforms[-1].std.tolist(), + ) + + image = prepare_img() + timm_pixel_values = transform(image).unsqueeze(0) + pixel_values = processor(image, return_tensors="pt").pixel_values + + # verify pixel values + assert torch.allclose(timm_pixel_values, pixel_values) + + # verify logits + with torch.no_grad(): + outputs = model(pixel_values) + logits = outputs.logits + + print("Predicted class:", logits.argmax(-1).item()) + if base_model: + timm_pooled_output = timm_model.forward_features(pixel_values) + assert timm_pooled_output.shape == outputs.pooler_output.shape + assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3) + else: + timm_logits = timm_model(pixel_values) + assert timm_logits.shape == outputs.logits.shape + assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {vit_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor to the hub {vit_name}") + model.push_to_hub(f"ybelkada/{vit_name}") + processor.push_to_hub(f"ybelkada/{vit_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--vit_name", + default="vit_base_r50_s16_384", + type=str, + help="Name of the hybrid ViT timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether to upload the model to the HuggingFace hub." + ) + + args = parser.parse_args() + convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..72410878933d7bd9726794588dc241c08b405a25 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/image_processing_vit_hybrid.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for ViT hybrid.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ....image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ....image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ....image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ....utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +class ViTHybridImageProcessor(BaseImageProcessor): + r""" + Constructs a ViT Hybrid image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `True`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to 224): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize: + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = True, + crop_size: Dict[str, int] = None, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"shortest_edge": 224} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + ] + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=False) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + images = make_list_of_images(images) + + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_center_crop: + image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["ViTHybridImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d2511019395f97ad790f9ab548667a4cac3b5b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -0,0 +1,770 @@ +# coding=utf-8 +# Copyright 2022 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ViT Hybrid model.""" + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ....activations import ACT2FN +from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + torch_int, +) +from ....utils.backbone_utils import load_backbone +from .configuration_vit_hybrid import ViTHybridConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ViTHybridConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "google/vit-hybrid-base-bit-384" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/vit-hybrid-base-bit-384" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class ViTHybridEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTHybridConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTHybridPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTHybridPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, feature_size=None): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + + self.backbone = load_backbone(config) + if self.backbone.config.model_type != "bit": + raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.") + feature_dim = self.backbone.channels[-1] + + if feature_size is None: + feature_map = config.backbone_featmap_shape + + feature_size = feature_map[-2:] + feature_dim = feature_map[1] + else: + feature_size = ( + feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size) + ) + feature_dim = self.backbone.channels[-1] + + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + + self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + + features = self.backbone(pixel_values).feature_maps[-1] + embeddings = self.projection(features).flatten(2).transpose(1, 2) + + return embeddings + + +class ViTHybridSelfAttention(nn.Module): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention_probs_dropout_prob = config.attention_probs_dropout_prob + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + head_mask, + self.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=None, + ) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + return context_layer, None + + +class ViTHybridSelfOutput(nn.Module): + """ + The residual connection is defined in ViTHybridLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTHybridAttention(nn.Module): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.attention = ViTHybridSelfAttention(config) + self.output = ViTHybridSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTHybridSdpaAttention(ViTHybridAttention): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + self.attention = ViTHybridSdpaSelfAttention(config) + + +class ViTHybridIntermediate(nn.Module): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTHybridOutput(nn.Module): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +VIT_HYBRID_ATTENTION_CLASSES = { + "eager": ViTHybridAttention, + "sdpa": ViTHybridSdpaAttention, +} + + +class ViTHybridLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config) + self.intermediate = ViTHybridIntermediate(config) + self.output = ViTHybridOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViTHybrid, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + # We assign to correct device for `accelerate`, check: https://github.com/huggingface/transformers/pull/20705/ + hidden_states = attention_output + hidden_states.to(attention_output.device) + + # in ViTHybrid, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTHybridEncoder(nn.Module): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTHybridLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTHybridPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTHybridConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"] + _supports_sdpa = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTHybridEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + module.mask_token.data.zero_() + + +VIT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTHybridConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`ViTHybridImageProcessor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ViT Hybrid Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class ViTHybridModel(ViTHybridPreTrainedModel): + def __init__(self, config: ViTHybridConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTHybridEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTHybridEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTHybridPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTHybridPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTHybridPooler(nn.Module): + def __init__(self, config: ViTHybridConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@add_start_docstrings( + """ + ViT Hybrid Model transformer with an image classification head on top (a linear layer on top of the final hidden + state of the [CLS] token) e.g. for ImageNet. + """, + VIT_START_DOCSTRING, +) +class ViTHybridForImageClassification(ViTHybridPreTrainedModel): + def __init__(self, config: ViTHybridConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.vit = ViTHybridModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["ViTHybridForImageClassification", "ViTHybridModel", "ViTHybridPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c13c67012fa157a94bae21734a1f59b26c78588d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ....utils import _LazyModule +from ....utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_xlm_prophetnet import * + from .modeling_xlm_prophetnet import * + from .tokenization_xlm_prophetnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7751d9541ec1a3b39799915d742955d863cd24 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/configuration_xlm_prophetnet.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""XLM-ProphetNet model configuration""" + +from typing import Callable, Optional, Union + +from ....configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class XLMProphetNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`XLMProphetNetModel`]. It is used to instantiate a + XLMProphetNet model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the XLMProphetNet + [microsoft/xprophetnet-large-wiki100-cased](https://huggingface.co/microsoft/xprophetnet-large-wiki100-cased) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for activations inside the fully connected layer. + activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the ProphetNET model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`XLMProphetNetModel`]. + hidden_size (`int`, *optional*, defaults to 1024): + Dimensionality of the layers and the pooler layer. + encoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. + num_encoder_layers (`int`, *optional*, defaults to 12): + Number of encoder layers. + num_encoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + decoder_ffn_dim (`int`, *optional*, defaults to 4096): + Dimensionality of the `intermediate` (often named feed-forward) layer in decoder. + num_decoder_layers (`int`, *optional*, defaults to 12): + Number of decoder layers. + num_decoder_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + init_std (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + add_cross_attention (`bool`, *optional*, defaults to `True`): + Whether cross-attention layers should be added to the model. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether this is an encoder/decoder model. + pad_token_id (`int`, *optional*, defaults to 1) + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0) + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2) + End of stream token id. + ngram (`int`, *optional*, defaults to 2) + Number of future tokens to predict. Set to 1 to be same as traditional Language model to predict next first + token. + num_buckets (`int`, *optional*, defaults to 32) + The number of buckets to use for each attention layer. This is for relative position calculation. See the + [T5 paper](see https://arxiv.org/abs/1910.10683) for more details. + relative_max_distance (`int`, *optional*, defaults to 128) + Relative distances greater than this number will be put into the last same bucket. This is for relative + position calculation. See the [T5 paper](see https://arxiv.org/abs/1910.10683) for more details. + disable_ngram_loss (`bool`, *optional*, defaults to `False`): + Whether be trained predicting only the next first token. + eps (`float`, *optional*, defaults to 0.0): + Controls the `epsilon` parameter value for label smoothing in the loss calculation. If set to 0, no label + smoothing is performed. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + """ + + model_type = "xlm-prophetnet" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_attention_heads": "num_encoder_attention_heads", + } + + def __init__( + self, + activation_dropout: Optional[float] = 0.1, + activation_function: Optional[Union[str, Callable]] = "gelu", + vocab_size: Optional[int] = 30522, + hidden_size: Optional[int] = 1024, + encoder_ffn_dim: Optional[int] = 4096, + num_encoder_layers: Optional[int] = 12, + num_encoder_attention_heads: Optional[int] = 16, + decoder_ffn_dim: Optional[int] = 4096, + num_decoder_layers: Optional[int] = 12, + num_decoder_attention_heads: Optional[int] = 16, + attention_dropout: Optional[float] = 0.1, + dropout: Optional[float] = 0.1, + max_position_embeddings: Optional[int] = 512, + init_std: Optional[float] = 0.02, + is_encoder_decoder: Optional[bool] = True, + add_cross_attention: Optional[bool] = True, + decoder_start_token_id: Optional[int] = 0, + ngram: Optional[int] = 2, + num_buckets: Optional[int] = 32, + relative_max_distance: Optional[int] = 128, + disable_ngram_loss: Optional[bool] = False, + eps: Optional[float] = 0.0, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = 0, + bos_token_id: Optional[int] = 1, + eos_token_id: Optional[int] = 2, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.encoder_ffn_dim = encoder_ffn_dim + self.num_encoder_layers = num_encoder_layers + self.num_encoder_attention_heads = num_encoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.num_decoder_layers = num_decoder_layers + self.num_decoder_attention_heads = num_decoder_attention_heads + self.max_position_embeddings = max_position_embeddings + self.init_std = init_std # Normal(0, this parameter) + self.activation_function = activation_function + + # parameters for xlmprophetnet + self.ngram = ngram + self.num_buckets = num_buckets + self.relative_max_distance = relative_max_distance + self.disable_ngram_loss = disable_ngram_loss + self.eps = eps + + # 3 Types of Dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.dropout = dropout + + self.use_cache = use_cache + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + add_cross_attention=add_cross_attention, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + + @property + def num_hidden_layers(self) -> int: + return self.num_encoder_layers + self.num_decoder_layers + + @num_hidden_layers.setter + def num_hidden_layers(self, value): + raise NotImplementedError( + "This model does not support the setting of `num_hidden_layers`. Please set `num_encoder_layers` and" + " `num_decoder_layers`." + ) + + +__all__ = ["XLMProphetNetConfig"] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..17bc9ffada60028e6531c90b3939cf22f846cdf3 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -0,0 +1,2346 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch XLM-ProphetNet model.""" + +import copy +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import Tensor, nn +from torch.nn import LayerNorm + +from ....activations import ACT2FN +from ....modeling_outputs import BaseModelOutput +from ....modeling_utils import PreTrainedModel +from ....utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_xlm_prophetnet import XLMProphetNetConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "XLMProphetNetConfig" + + +XLM_PROPHETNET_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + Original ProphetNet code can be found [here](https://github.com/microsoft/ProphetNet). Checkpoints were converted + from original Fairseq checkpoints. For more information on the checkpoint conversion, please take a look at the + file `convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py`. + + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and + behavior. + + Parameters: + config ([`XLMProphetNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +XLM_PROPHETNET_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + XLMProphetNet uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def softmax(hidden_state, dim, onnx_trace=False): + if onnx_trace: + return nn.functional.softmax(hidden_state.float(), dim=dim) + else: + return nn.functional.softmax(hidden_state, dim=dim, dtype=torch.float32) + + +def ngram_attention_bias(sequence_length, ngram, device, dtype): + """ + This function computes the bias for the predict stream + """ + left_block = ( + torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min + ) + right_block = left_block.detach().clone() + # create bias + for stream_idx in range(ngram): + right_block[stream_idx].fill_diagonal_(0, wrap=False) + left_block[stream_idx].triu_(-stream_idx + 1) + + left_block[:, :, 0] = 0 + return torch.cat([left_block, right_block], dim=2) + + +def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False): + """ + This function computes individual parts of the relative position buckets. For more detail, see paper. + """ + inv_relative_positions = -relative_positions + rel_positions_bucket = 0 + + if is_bidirectional: + num_buckets = num_buckets // 2 + rel_positions_bucket = ( + rel_positions_bucket + + torch.lt(inv_relative_positions, torch.zeros_like(inv_relative_positions)).int() * num_buckets + ) + inv_relative_positions = torch.abs(inv_relative_positions) + else: + inv_relative_positions = torch.max(inv_relative_positions, torch.zeros_like(inv_relative_positions)) + + max_exact = num_buckets // 2 + is_small = torch.lt(inv_relative_positions, max_exact) + val_if_large = max_exact + torch.log(inv_relative_positions.float() / max_exact) / math.log( + max_distance / max_exact + ) * (num_buckets - max_exact) + val_if_large = torch.min(val_if_large, torch.ones_like(val_if_large) * (num_buckets - 1)).int() + rel_positions_bucket = rel_positions_bucket + torch.where(is_small, inv_relative_positions.int(), val_if_large) + return rel_positions_bucket + + +def compute_all_stream_relative_buckets(num_buckets, max_distance, position_ids): + """ + This function computes both main and predict relative position buckets. For more detail, see paper. + """ + # main stream + main_stream_relative_positions = position_ids.unsqueeze(1).repeat(1, position_ids.size(-1), 1) + main_stream_relative_positions = main_stream_relative_positions - position_ids.unsqueeze(-1) + + # predicting stream + predicting_stream_relative_positions = torch.cat((position_ids - 1, position_ids), dim=-1).unsqueeze(1) + predicting_stream_relative_positions = predicting_stream_relative_positions.repeat(1, position_ids.size(-1), 1) + predicting_stream_relative_positions = predicting_stream_relative_positions - position_ids.unsqueeze(-1) + + # get both position buckets + main_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, main_stream_relative_positions, is_bidirectional=False + ) + predict_relative_position_buckets = compute_relative_buckets( + num_buckets, max_distance, predicting_stream_relative_positions, is_bidirectional=False + ) + return main_relative_position_buckets, predict_relative_position_buckets + + +@dataclass +class XLMProphetNetSeq2SeqLMOutput(ModelOutput): + """ + Base class for sequence-to-sequence language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`): + Prediction scores of the main stream language modeling head (scores for each vocabulary token before + SoftMax). + logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Prediction scores of the predict stream language modeling head (scores for each vocabulary token before + SoftMax). + past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, + num_attn_heads, decoder_sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, decoder_sequence_length, hidden_size)`. + + Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs. + decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to + compute the weighted average in the + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, encoder_sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, encoder_sequence_length)`. Attentions weights of the encoder, after the attention + softmax, used to compute the weighted average in the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + logits_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[torch.FloatTensor]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`" + " instead.", + FutureWarning, + ) + return self.cross_attentions + + +@dataclass +class XLMProphetNetSeq2SeqModelOutput(ModelOutput): + """ + Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential + decoding. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`): + Sequence of main stream hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size,ngram * decoder_sequence_length, config.vocab_size)`, *optional*): + Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model. + past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, + num_attn_heads, decoder_sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, decoder_sequence_length, hidden_size)`. + + Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs. + decoder_ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + decoder_ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to + compute the weighted average in the + encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder of the model. + encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, encoder_sequence_length, hidden_size)`. + + Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. + encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, encoder_sequence_length)`. + + Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + last_hidden_state: torch.FloatTensor + last_hidden_state_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[torch.FloatTensor]] = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + encoder_last_hidden_state: Optional[torch.FloatTensor] = None + encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions`" + " instead.", + FutureWarning, + ) + return self.cross_attentions + + +@dataclass +class XLMProphetNetDecoderModelOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, hidden_size)`): + Sequence of main stream hidden-states at the output of the last layer of the decoder of the model. + + If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, + hidden_size)` is output. + last_hidden_state_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Sequence of predict stream hidden-states at the output of the last layer of the decoder of the model. + past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, + num_attn_heads, decoder_sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, decoder_sequence_length, hidden_size)`. + + Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs. + ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to + compute the weighted average in the + """ + + last_hidden_state: torch.FloatTensor + last_hidden_state_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class XLMProphetNetDecoderLMOutput(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, decoder_sequence_length, config.vocab_size)`): + Prediction scores of the main stream language modeling head (scores for each vocabulary token before + SoftMax). + logits_ngram (`torch.FloatTensor` of shape `(batch_size, ngram * decoder_sequence_length, config.vocab_size)`): + Prediction scores of the predict stream language modeling head (scores for each vocabulary token before + SoftMax). + past_key_values (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, + num_attn_heads, decoder_sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be + used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, decoder_sequence_length, hidden_size)`. + + Hidden-states of main stream of the decoder at the output of each layer plus the initial embedding outputs. + ngram_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, ngram * decoder_sequence_length, hidden_size)`. + + Hidden-states of the predict stream of the decoder at the output of each layer plus the initial embedding + outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the + self-attention heads. + ngram_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + decoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the + weighted average in the + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_attn_heads, + encoder_sequence_length, decoder_sequence_length)`. + + Attentions weights of the cross-attention layer of the decoder, after the attention softmax, used to + compute the weighted average in the + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + logits_ngram: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_ngram: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class XLMProphetNetPreTrainedModel(PreTrainedModel): + config_class = XLMProphetNetConfig + base_model_prefix = "prophetnet" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In XLMProphetNet it is usually set to the" + " pad_token_id. See XLMProphetNet docs for more information" + ) + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +class XLMProphetNetPositionalEmbeddings(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting + based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to + the forward function. + """ + + def __init__(self, config: XLMProphetNetConfig) -> None: + self.max_length = config.max_position_embeddings + super().__init__(config.max_position_embeddings, config.hidden_size, config.pad_token_id) + + def forward(self, inputs_shape, device, attention_mask=None, past_key_values=None, position_ids=None): + assert (position_ids is None) or (self.padding_idx is None), ( + "If position_ids is pre-computed then padding_idx should not be set." + ) + + if position_ids is None: + if past_key_values is not None: + # position_ids is the same for every token when decoding a single step + # Without the int() cast, it doesn't work in some cases when exporting to ONNX + prev_num_input_ids = past_key_values[0][0].shape[2] + num_input_ids = inputs_shape[1] + prev_num_input_ids + position_ids = torch.ones((1, 1), dtype=torch.long, device=device) * ( + int(self.padding_idx + num_input_ids) + ) + else: + if attention_mask is None: + attention_mask = torch.ones(inputs_shape, dtype=torch.long, device=device) + + # retrieve position_ids from input_ids / attention_mask + position_ids = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask + ).long() + self.padding_idx + + # make sure position_ids are not bigger then max_length + position_ids = position_ids.clamp(0, self.max_length - 1) + + return super().forward(position_ids), position_ids + + def _forward(self, position_ids): + return super().forward(position_ids) + + +class XLMProphetNetAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: XLMProphetNetConfig, + num_attn_heads: int, + ): + super().__init__() + hidden_size = config.hidden_size + + self.attention_dropout = config.attention_dropout + self.dropout = config.dropout + self.num_attn_heads = num_attn_heads + self.head_dim = hidden_size // num_attn_heads + + assert self.head_dim * num_attn_heads == hidden_size, ( + "`config.hidden_size` must be divisible by `config.num_encoder_attention_heads` and" + " `config.num_decoder_attention_heads`" + ) + + self.key_proj = nn.Linear(hidden_size, hidden_size) + self.value_proj = nn.Linear(hidden_size, hidden_size) + self.query_proj = nn.Linear(hidden_size, hidden_size) + + self.out_proj = nn.Linear(hidden_size, hidden_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states, + key_value_states: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + layer_head_mask: Optional[Tensor] = None, + past_key_value: Optional[Tuple[Tensor]] = None, + output_attentions: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + batch_size, tgt_len, hidden_size = hidden_states.size() + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + assert list(hidden_states.size()) == [ + batch_size, + tgt_len, + hidden_size, + ], f"Size of hidden states should be {batch_size, tgt_len, hidden_size}, but is {hidden_states.size()}" + + # previous time steps are cached - no need to recompute key and value if they are static + query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) + value_states = self._shape(self.value_proj(key_value_states), -1, batch_size) + else: + # self_attention + key_states = self._shape(self.key_proj(hidden_states), -1, batch_size) + value_states = self._shape(self.value_proj(hidden_states), -1, batch_size) + + if is_cross_attention: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + # project states into the correct shape + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) + if attn_weights.size() != expected_shape: + raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}") + + # This is part of a workaround to get around fork/join parallelism not supporting Optional types. + if attention_mask is not None and attention_mask.dim() == 0: + attention_mask = None + + expected_shape = (batch_size, self.num_attn_heads, 1, src_len) + if attention_mask is not None and attention_mask.size() != expected_shape: + raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}") + if attention_mask is not None: # don't attend to padding symbols + attn_weights = attn_weights + attention_mask + if output_attentions: + attn_weights_reshaped = attn_weights + else: + attn_weights_reshaped = None + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + batch_size, self.num_attn_heads, tgt_len, src_len + ) + + # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model + attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped + + attn_probs = nn.functional.dropout( + attn_weights, + p=self.attention_dropout, + training=self.training, + ) + attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim) + if attn_output.size() != expected_shape: + raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size) + attn_output = self.out_proj(attn_output) + + attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) + return attn_output, attn_weights_reshaped, past_key_value + + +class XLMProphetNetFeedForward(nn.Module): + """ + This is the residual two feed-forward layer block based on the original Transformer implementation. + """ + + def __init__(self, config: XLMProphetNetConfig, ffn_dim: int): + super().__init__() + self.activation_fn = ACT2FN[config.activation_function] + self.intermediate = nn.Linear(config.hidden_size, ffn_dim) + self.output = nn.Linear(ffn_dim, config.hidden_size) + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states): + hidden_states = self.intermediate(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.output(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states + + +class XLMProphetNetNgramSelfAttention(nn.Module): + def __init__(self, config: XLMProphetNetConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.num_buckets = config.num_buckets + self.relative_max_distance = config.relative_max_distance + self.num_attn_heads = config.num_decoder_attention_heads + self.dropout = config.dropout + self.attention_dropout = config.attention_dropout + self.head_dim = config.hidden_size // self.num_attn_heads + self.ngram = config.ngram + + assert self.head_dim * self.num_attn_heads == config.hidden_size, ( + "config.hidden_size must be divisible by num_attn_heads" + ) + # key, value, query projection + self.key_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.value_proj = nn.Linear(config.hidden_size, config.hidden_size) + self.query_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # out projection + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) + + # rel position embeddings + self.relative_pos_embeddings = nn.Linear(config.hidden_size, self.num_buckets * self.num_attn_heads) + + # for onnx runtime + self.onnx_trace = False + + def _shape(self, tensor, seq_len, batch_size): + return tensor.view(batch_size, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() + + def prepare_for_onnx_export_(self): + self.onnx_trace = True + + def forward( + self, + hidden_states, + past_key_value: Optional[Tuple[Tensor]] = None, + attention_mask=None, + layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + ): + batch_size, ngram_sequence_length, hidden_size = hidden_states.size() + assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( + f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape" + f" {hidden_states.shape}" + ) + + # project + query_states = self.query_proj(hidden_states) + key_states = self.key_proj(hidden_states) + value_states = self.value_proj(hidden_states) + + # normalize + query_states = query_states / (self.head_dim**0.5) + + # reshape + query_states = self._shape(query_states, ngram_sequence_length, batch_size) + key_states = self._shape(key_states, -1, batch_size) + value_states = self._shape(value_states, -1, batch_size) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) + + query_states = query_states.view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + # chunk into main stream and predict stream + hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) + query_states_list = query_states.chunk(1 + self.ngram, dim=2) + key_states_list = key_states.chunk(1 + self.ngram, dim=2) + value_states_list = value_states.chunk(1 + self.ngram, dim=2) + + main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] + main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] + main_key_states, predict_key_states_list = key_states_list[0], key_states_list[1:] + main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] + + # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) + if past_key_value is not None: + prev_main_key_states = past_key_value[0] + main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2) + prev_main_value_states = past_key_value[1] + main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2) + + # Update cache + past_key_value = (main_key_states, main_value_states) + + # get seq_length of main stream only + sequence_length = ngram_sequence_length // (1 + self.ngram) + + # MAIN-STREAM + # main attn weights + # [batch_size, number_heads, sequence_length, head_dimesion] + # x [batch_size, number_heads, head_dimesion, sequence_length] + # -> [batch_size, number_heads, sequence_length, sequence_length] + main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3)) + + # retrieve relative position embeddings for each layer -> see paper for more details + main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( + main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets + ) + + main_attn_weights = main_attn_weights + main_relative_pos_embeddings + + if attention_mask is not None: + main_attn_weights = main_attn_weights + attention_mask + + main_attn_probs = softmax( + main_attn_weights, + dim=-1, + onnx_trace=self.onnx_trace, + ).type_as(main_attn_weights) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view( + batch_size, self.num_attn_heads, -1, sequence_length + ) + + main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) + # project to attn_output + # [batch_size, number_heads, sequence_length, sequence_length] + # x [batch_size, number_heads, sequence_length, head_dimesion] + # -> [batch_size, number_heads, sequence_length, head_dimesion] + main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states) + # reshape so that num_heads dim is merged into last `head_dim` axis + main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size) + main_attn_output = self.out_proj(main_attn_output) + + # PREDICT-STREAM + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_query_states = torch.stack(predict_query_states_list, 1).view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim + ) + + # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1) + + # [batch_size, sequence_length, ngram, hidden_size] + predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2) + + # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion] + predict_value_states = torch.cat( + [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2 + ) + + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states)) + + # retrieve relative position embeddings for each layer -> see paper for more details + # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings] + predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings( + predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets + ) + + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings + + if extended_predict_attention_mask is not None: + # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4) + extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype) + predict_attn_weights = predict_attn_weights + extended_predict_attention_mask + + predict_attn_probs = softmax( + predict_attn_weights, + dim=-1, + onnx_trace=self.onnx_trace, + ).type_as(predict_attn_weights) + + if layer_head_mask is not None: + assert layer_head_mask.size() == (self.num_attn_heads,), ( + f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs + + predict_attn_probs = nn.functional.dropout( + predict_attn_probs, p=self.attention_dropout, training=self.training + ) + # project to attention output + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_attn_output = torch.einsum( + "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2)) + ) + + # reshape so that num_heads dim is merged into last `head_dim` axis + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size] + predict_attn_output = predict_attn_output.transpose(2, 3) + predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size) + predict_attn_output = self.out_proj(predict_attn_output) + + # concat to single attn output + # [batch_size, (1+ngram)*sequence_length, hidden_size] + attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size) + # reshape into better form for `config.output_attentions` + main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1) + + attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) + + return attn_output, main_attn_probs, predict_attn_probs, past_key_value + + def get_main_relative_pos_embeddings( + self, hidden_states, attn_weights, position_ids, main_relative_position_buckets + ): + # input hidden_states [batch_size, sequence_length, hidden_size] + # input attn_weights [batch_size, num_heads, sequence_length, sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape + attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len) + if main_relative_position_buckets is None: + batch_size, sequence_length = hidden_states.shape[:2] + relative_positions = ( + torch.arange(1, attn_weights.shape[-1] + 1) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, sequence_length, 1) + .to(position_ids.device) + ) + # [batch_size, sequence_length, sequence_length+1] + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) + main_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + # [batch_size, sequence_length, num_buckets * num_heads] + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + rel_pos_embeddings = rel_pos_embeddings.view( + rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2) + # [batch_size, num_heads, sequence_length, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,)) + + main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) + # [batch_size * num_heads * sequence_length, sequence_length] + main_relative_position_buckets = main_relative_position_buckets.view( + -1, main_relative_position_buckets.shape[-1] + ) + main_relative_position_buckets = main_relative_position_buckets.long() + # [batch_size * num_heads * sequence_length, sequence_length] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + + main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets) + main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1) + return main_relative_pos_embeddings + + def get_predict_relative_pos_embeddings( + self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets + ): + # input hidden_states [batch_size, sequence_length, ngram, hidden_size] + # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None + batch_size, sequence_length = hidden_states.shape[0:2] + + if predict_relative_position_buckets is None: + key_sequence_length = attn_weights.shape[-1] + assert position_ids[0][0] == key_sequence_length - 1, ( + "`position_ids` are incorrect. They should be of the format 1 2 3 4 5 ... (key_sequence_length - 1)" + ) + relative_positions = ( + torch.arange(0, key_sequence_length) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch_size, sequence_length, 1) + .to(position_ids.device) + ) + + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) + predict_relative_position_buckets = compute_relative_buckets( + self.num_buckets, self.relative_max_distance, relative_positions, False + ) + + # [batch_size, ngram, sequence_length, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + + # [batch_size, ngram, sequence_length, num_buckets, num_heads] + rel_pos_embeddings = rel_pos_embeddings.view( + hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads) + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3) + # [batch_size * ngram * sequence_length * num_heads, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets) + # [ngram, batch_size, num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0) + predict_relative_position_buckets = predict_relative_position_buckets.repeat( + self.ngram, 1, self.num_attn_heads, 1 + ) + # [ngram * batch_size * num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.view( + -1, predict_relative_position_buckets.size(-1) + ).long() + + predict_relative_pos_embeddings = torch.gather( + rel_pos_embeddings, dim=1, index=predict_relative_position_buckets + ) + + # [batch_size, gram, num_heads, sequence_length, -1] + predict_relative_pos_embeddings = predict_relative_pos_embeddings.view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, -1 + ) + + return predict_relative_pos_embeddings + + +class XLMProphetNetEncoderLayer(nn.Module): + """ + Encoder block for XLMProphetnet + """ + + def __init__(self, config: XLMProphetNetConfig): + super().__init__() + # 1st residual block + self.self_attn = XLMProphetNetAttention(config, config.num_encoder_attention_heads) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) + + # 2nd residual block + self.feed_forward = XLMProphetNetFeedForward(config, config.encoder_ffn_dim) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states, + attention_mask, + layer_head_mask, + output_attentions: bool = False, + ): + # 1st residual block + attention_output, attn_weights, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_layer_norm(attention_output + hidden_states) + + # 2nd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class XLMProphetNetDecoderLayer(nn.Module): + """ + Decoder block for XLMProphetnet + """ + + def __init__(self, config: XLMProphetNetConfig): + super().__init__() + # 1st residual block + self.self_attn = XLMProphetNetNgramSelfAttention(config) + self.self_attn_layer_norm = LayerNorm(config.hidden_size) + + # 2nd residual block + if config.add_cross_attention: + self.cross_attn = XLMProphetNetAttention(config, config.num_decoder_attention_heads) + self.cross_attn_layer_norm = LayerNorm(config.hidden_size) + + # 3rd residual block + self.feed_forward = XLMProphetNetFeedForward(config, config.decoder_ffn_dim) + self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attn_mask=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + extended_predict_attention_mask=None, + main_relative_position_buckets=None, + predict_relative_position_buckets=None, + position_ids=None, + past_key_value=None, + use_cache: bool = True, + output_attentions: bool = False, + ): + # 1st residual block + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + ) + hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_weights = None + if encoder_hidden_states is not None: + # 2nd residual block + attention_output, cross_attn_weights, cross_attn_present_key_value = self.cross_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attn_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # 3rd residual block + feed_forward_output = self.feed_forward(hidden_states) + hidden_states = self.feed_forward_layer_norm(feed_forward_output + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, self_attn_weights_ngram, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The standalone encoder part of the XLMProphetNetModel.", + XLM_PROPHETNET_START_DOCSTRING, +) +class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): + r""" + word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): + The word embedding parameters. This can be used to initialize [`XLMProphetNetEncoder`] with pre-defined word + embeddings instead of randomly initialized word embeddings. + """ + + def __init__(self, config: XLMProphetNetConfig, word_embeddings: nn.Embedding = None): + super().__init__(config) + + self.word_embeddings = ( + word_embeddings + if word_embeddings is not None + else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + ) + self.position_embeddings = XLMProphetNetPositionalEmbeddings(config) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) + + self.layers = nn.ModuleList([XLMProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, XLMProphetNetEncoder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = XLMProphetNetEncoder.from_pretrained("patrickvonplaten/prophetnet-large-uncased-standalone") + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("Either input_ids or inputs_embeds has to be passed.") + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("Make sure to only pass input_ids or inputs_embeds.") + elif input_ids is not None and inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # prepare attention mask + if attention_mask is not None: + extended_attention_mask = ( + 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype) + else: + extended_attention_mask = None + + position_embeddings, position_ids = self.position_embeddings(inputs_embeds.shape[:2], inputs_embeds.device) + + hidden_states = inputs_embeds + position_embeddings + hidden_states = self.embeddings_layer_norm(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.config.dropout, training=self.training) + + encoder_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_hidden_states = encoder_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + extended_attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_hidden_states = encoder_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_hidden_states, attentions=all_attentions + ) + + +@add_start_docstrings( + "The standalone decoder part of the XLMProphetNetModel.", + XLM_PROPHETNET_START_DOCSTRING, +) +class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): + r""" + word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): + The word embedding parameters. This can be used to initialize [`XLMProphetNetEncoder`] with pre-defined word + embeddings instead of randomly initialized word embeddings. + """ + + def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + super().__init__(config) + + self.ngram = config.ngram + self.num_buckets = config.num_buckets + self.relative_max_distance = config.relative_max_distance + self.dropout = config.dropout + self.max_target_positions = config.max_position_embeddings + + self.word_embeddings = ( + word_embeddings + if word_embeddings is not None + else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + ) + self.position_embeddings = XLMProphetNetPositionalEmbeddings(config) + + self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) + self.layers = nn.ModuleList([XLMProphetNetDecoderLayer(config) for _ in range(config.num_decoder_layers)]) + self.embeddings_layer_norm = LayerNorm(config.hidden_size) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + + @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=XLMProphetNetDecoderModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, XLMProphetNetDecoderModelOutput]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, XLMProphetNetDecoder + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = XLMProphetNetDecoder.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone", add_cross_attention=False) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None and inputs_embeds is None: + raise ValueError("Either `decoder_input_ids` or `decoder_inputs_embeds` has to be passed.") + elif input_ids is not None and inputs_embeds is not None: + raise ValueError("Make sure to only pass `decoder_input_ids` or `decoder_inputs_embeds`.") + elif input_ids is not None and inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + batch_size, sequence_length = inputs_embeds.shape[:2] + + main_stream_pos_embed, position_ids = self.position_embeddings( + (batch_size, sequence_length), + device=inputs_embeds.device, + past_key_values=past_key_values, + ) + + if past_key_values is not None: + main_relative_position_buckets, predict_relative_position_buckets = None, None + else: + ( + main_relative_position_buckets, + predict_relative_position_buckets, + ) = self.compute_buffered_relative_buckets(position_ids) + predicting_stream_pos_embed = self.position_embeddings._forward(position_ids + 1) + + # add position embeddings + hidden_states = inputs_embeds + main_stream_pos_embed + + ngram_embeddings = self.ngram_embeddings.weight + + # prepare attention mask + if past_key_values is not None: + assert hidden_states.size(1) == 1, ( + "At the moment `use_cache` is only supported for `decoder_input_ids` of length 1" + ) + + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed).repeat(batch_size, 1, 1) + for ngram in range(self.ngram) + ] + extended_attention_mask = None + extended_predict_attention_mask = None + else: + ngram_hidden_states = [ + (ngram_embeddings[ngram - 1] + predicting_stream_pos_embed) for ngram in range(self.ngram) + ] + extended_attention_mask = self.prepare_attention_mask(hidden_states, attention_mask) + extended_predict_attention_mask = self.prepare_predict_attention_mask(hidden_states, attention_mask) + + # prepare encoder attention mask + if encoder_attention_mask is not None: + extended_encoder_attention_mask = ( + 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1) + ) * torch.finfo(self.dtype).min + extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype) + else: + extended_encoder_attention_mask = None + + hidden_states = torch.cat([hidden_states] + ngram_hidden_states, 1) + + if self.embeddings_layer_norm: + hidden_states = self.embeddings_layer_norm(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # init attentions, hidden_states and cache with empty tuples + all_main_stream_hidden_states = () if output_hidden_states else None + all_ngram_stream_hidden_states = () if output_hidden_states and self.config.ngram > 0 else None + + all_main_stream_attns = () if output_attentions else None + all_ngram_stream_attns = () if output_attentions else None + all_cross_attns = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + present_key_values = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + assert attn_mask.size()[0] == (len(self.layers)), ( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + # grad cannot be kept because tensor is sliced + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + extended_attention_mask, + encoder_hidden_states, + extended_encoder_attention_mask, + (head_mask[idx] if head_mask is not None else None), + (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + extended_predict_attention_mask, + main_relative_position_buckets, + predict_relative_position_buckets, + position_ids, + None, + use_cache, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + present_key_values += (layer_outputs[4 if output_attentions else 1],) + + if output_attentions: + all_main_stream_attns += (layer_outputs[1],) + all_ngram_stream_attns += (layer_outputs[2],) + + if self.config.add_cross_attention: + all_cross_attns += (layer_outputs[3],) + + if output_hidden_states: + all_main_stream_hidden_states += (hidden_states[:, :sequence_length],) + if self.config.ngram > 0: + all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) + + # split last_hidden_state for return + last_hidden_state = hidden_states[:, :sequence_length] + last_hidden_state_ngram = hidden_states[:, sequence_length:] if self.config.ngram > 0 else None + + if not return_dict: + return tuple( + v + for v in [ + last_hidden_state, + last_hidden_state_ngram, + present_key_values, + all_main_stream_hidden_states, + all_ngram_stream_hidden_states, + all_main_stream_attns, + all_ngram_stream_attns, + all_cross_attns, + ] + if v is not None + ) + return XLMProphetNetDecoderModelOutput( + last_hidden_state=last_hidden_state, + last_hidden_state_ngram=last_hidden_state_ngram, + past_key_values=present_key_values, + hidden_states=all_main_stream_hidden_states, + hidden_states_ngram=all_ngram_stream_hidden_states, + attentions=all_main_stream_attns, + ngram_attentions=all_ngram_stream_attns, + cross_attentions=all_cross_attns, + ) + + def compute_buffered_relative_buckets(self, position_ids): + batch_size, sequence_length = position_ids.shape + + position_ids = torch.arange(1, self.max_target_positions).to(position_ids.device).repeat(1, 1) + main_relative_buckets, predict_relative_buckets = compute_all_stream_relative_buckets( + self.num_buckets, self.relative_max_distance, position_ids + ) + + # buffer relative buckets + main_relative_buckets = main_relative_buckets[:, :sequence_length, :sequence_length].repeat(batch_size, 1, 1) + predict_relative_buckets = torch.cat( + [ + predict_relative_buckets[:, :sequence_length, :sequence_length], + predict_relative_buckets[ + :, :sequence_length, self.max_target_positions : self.max_target_positions + sequence_length + ], + ], + 2, + ).repeat(batch_size, 1, 1) + + return main_relative_buckets, predict_relative_buckets + + def prepare_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + causal_mask = torch.full( + (seq_length, seq_length), + torch.finfo(hidden_states.dtype).min, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + causal_mask = torch.triu(causal_mask, 1) + + extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_causal_mask + extended_attention_mask + else: + extended_attention_mask = extended_causal_mask + return extended_attention_mask.to(hidden_states.dtype) + + def prepare_predict_attention_mask(self, hidden_states, attention_mask): + batch_size, seq_length = hidden_states.shape[:2] + + # get causal mask + predict_causal_mask = ngram_attention_bias( + self.max_target_positions, self.ngram, hidden_states.device, hidden_states.dtype + ) + predict_causal_mask = torch.cat( + [ + predict_causal_mask[:, :seq_length, :seq_length], + predict_causal_mask[ + :, :seq_length, self.max_target_positions : self.max_target_positions + seq_length + ], + ], + dim=-1, + ) + extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape + ) + + # add usual attention mask + if attention_mask is not None: + extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.expand( + (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length) + ) + # predicted stream attention_mask should always be 0 + extended_attention_mask = torch.cat( + [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1 + ) + extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask + else: + extended_predict_attention_mask = extended_predict_causal_mask + return extended_predict_attention_mask.to(hidden_states.dtype) + + +@add_start_docstrings( + "The bare XLMProphetNet Model outputting raw hidden-states without any specific head on top.", + XLM_PROPHETNET_START_DOCSTRING, +) +class XLMProphetNetModel(XLMProphetNetPreTrainedModel): + _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"] + + def __init__(self, config: XLMProphetNetConfig): + super().__init__(config) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + encoder_config = copy.deepcopy(config) + encoder_config.is_encoder_decoder = False + encoder_config.use_cache = False + self.encoder = XLMProphetNetEncoder(encoder_config, self.word_embeddings) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + self.decoder = XLMProphetNetDecoder(decoder_config, self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.word_embeddings + + def set_input_embeddings(self, value): + self.word_embeddings = value + self.encoder.word_embeddings = self.word_embeddings + self.decoder.word_embeddings = self.word_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.word_embeddings, self.word_embeddings) + self._tie_or_clone_weights(self.decoder.word_embeddings, self.word_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, XLMProphetNetSeq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, XLMProphetNetModel + + >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = XLMProphetNetModel.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states + >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + return XLMProphetNetSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + last_hidden_state_ngram=decoder_outputs.last_hidden_state_ngram, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram, + decoder_attentions=decoder_outputs.attentions, + decoder_ngram_attentions=decoder_outputs.ngram_attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The XLMProphetNet Model with a language modeling head. Can be used for sequence generation tasks.", + XLM_PROPHETNET_START_DOCSTRING, +) +class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel): + _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] + + def __init__(self, config: XLMProphetNetConfig): + super().__init__(config) + self.prophetnet = XLMProphetNetModel(config) + self.padding_idx = config.pad_token_id + self.disable_ngram_loss = config.disable_ngram_loss + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.word_embeddings, self.lm_head) + + def get_input_embeddings(self): + return self.prophetnet.word_embeddings + + @add_start_docstrings_to_model_forward(XLM_PROPHETNET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=XLMProphetNetSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.Tensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, XLMProphetNetSeq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, XLMProphetNetForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = XLMProphetNetForConditionalGeneration.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + + >>> logits_next_token = outputs.logits # logits to predict next token as usual + >>> logits_ngram_next_tokens = outputs.logits_ngram # logits to predict 2nd, 3rd, ... next tokens + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + outputs = self.prophetnet( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + batch_size, sequence_length = ( + decoder_input_ids.shape if decoder_input_ids is not None else decoder_inputs_embeds.shape[:2] + ) + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + # To use .view in loss computation, make sure that logits is contiguous. + if not logits.is_contiguous(): + logits = logits.contiguous() + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return XLMProphetNetSeq2SeqLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states, + decoder_attentions=outputs.decoder_attentions, + decoder_ngram_attentions=outputs.decoder_ngram_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + assert encoder_outputs is not None, "`encoder_outputs` have to be passed for generation." + + if past_key_values: + decoder_input_ids = decoder_input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + def get_encoder(self): + return self.prophetnet.encoder + + def get_decoder(self): + return self.prophetnet.decoder + + +@add_start_docstrings( + "The standalone decoder part of the XLMProphetNetModel with a lm head on top. The model can be used for causal" + " language modeling.", + XLM_PROPHETNET_START_DOCSTRING, +) +class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel): + _tied_weights_keys = [ + "prophetnet.word_embeddings.weight", + "prophetnet.decoder.word_embeddings.weight", + "lm_head.weight", + ] + + def __init__(self, config: XLMProphetNetConfig): + # set config for CLM + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + super().__init__(config) + self.prophetnet = XLMProphetNetDecoderWrapper(config) + + self.padding_idx = config.pad_token_id + self.disable_ngram_loss = config.disable_ngram_loss + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.prophetnet.decoder.word_embeddings + + def set_input_embeddings(self, value): + self.prophetnet.decoder.word_embeddings = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) + + def set_decoder(self, decoder): + self.prophetnet.decoder = decoder + + def get_decoder(self): + return self.prophetnet.decoder + + @add_start_docstrings_to_model_forward(XLM_PROPHETNET_STANDALONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=XLMProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, XLMProphetNetDecoderLMOutput]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, XLMProphetNetForCausalLM + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = XLMProphetNetForCausalLM.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + + >>> # Model can also be used with EncoderDecoder framework + >>> from transformers import BertTokenizer, EncoderDecoderModel, AutoTokenizer + >>> import torch + + >>> tokenizer_enc = BertTokenizer.from_pretrained("google-bert/bert-large-uncased") + >>> tokenizer_dec = AutoTokenizer.from_pretrained("patrickvonplaten/xprophetnet-large-uncased-standalone") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained( + ... "google-bert/bert-large-uncased", "patrickvonplaten/xprophetnet-large-uncased-standalone" + ... ) + + >>> ARTICLE = ( + ... "the us state department said wednesday it had received no " + ... "formal word from bolivia that it was expelling the us ambassador there " + ... "but said the charges made against him are `` baseless ." + ... ) + >>> input_ids = tokenizer_enc(ARTICLE, return_tensors="pt").input_ids + >>> labels = tokenizer_dec( + ... "us rejects charges against its ambassador in bolivia", return_tensors="pt" + ... ).input_ids + >>> outputs = model(input_ids=input_ids, decoder_input_ids=labels[:, :-1], labels=labels[:, 1:]) + + >>> loss = outputs.loss + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) + outputs = self.prophetnet.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + batch_size, sequence_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] + + predicting_streams = outputs[1].view(batch_size, self.config.ngram, sequence_length, -1) + predict_logits = self.lm_head(predicting_streams) + + logits = predict_logits[:, 0] + logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None + + loss = None + if labels is not None: + loss = self._compute_loss(predict_logits, labels) + + if not return_dict: + all_logits = tuple(v for v in [logits, logits_ngram] if v is not None) + return (loss,) + all_logits + outputs[2:] if loss is not None else all_logits + outputs[2:] + else: + return XLMProphetNetDecoderLMOutput( + loss=loss, + logits=logits, + logits_ngram=logits_ngram, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + hidden_states_ngram=outputs.hidden_states_ngram, + attentions=outputs.attentions, + ngram_attentions=outputs.ngram_attentions, + cross_attentions=outputs.cross_attentions, + ) + + def _compute_loss(self, logits, labels, ignore_index=-100): + expend_targets = labels.new_zeros(self.config.ngram, labels.size(0), labels.size(1)).fill_(ignore_index) + + for i in range(self.config.ngram): + if i > 0 and self.disable_ngram_loss: + break + expend_targets[i, :, :] = labels + + logits = logits.transpose(0, 1).contiguous() + lprobs = nn.functional.log_softmax( + logits.view(-1, logits.size(-1)), + dim=-1, + dtype=torch.float32, + ) + + loss = nn.functional.nll_loss(lprobs, expend_targets.view(-1), reduction="mean") + + if self.config.eps > 0.0: + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + non_masked_tokens = expend_targets.ne(ignore_index).view(-1) + smooth_loss = smooth_loss[non_masked_tokens] + smooth_loss = smooth_loss.mean() + + eps_i = self.config.eps / lprobs.size(-1) + loss = (1.0 - self.config.eps) * loss + eps_i * smooth_loss + + return loss + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + if past_key_values: + input_ids = input_ids[:, -1:] + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed + "attention_mask": attention_mask, + "head_mask": head_mask, + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel): + """ + This is a wrapper class, so that [`XLMProphetNetForCausalLM`] can correctly be loaded from pretrained XLMProphetNet + classes. + """ + + def __init__(self, config: XLMProphetNetConfig): + super().__init__(config) + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.decoder = XLMProphetNetDecoder(config, word_embeddings=self.word_embeddings) + + # Initialize weights and apply final processing + self.post_init() + + def _tie_weights(self): + self._tie_or_clone_weights(self.word_embeddings, self.decoder.get_input_embeddings()) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +__all__ = [ + "XLMProphetNetDecoder", + "XLMProphetNetEncoder", + "XLMProphetNetForCausalLM", + "XLMProphetNetForConditionalGeneration", + "XLMProphetNetModel", + "XLMProphetNetPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5da12859f29e9e3cabe354c63d25c584edbd49 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/deprecated/xlm_prophetnet/tokenization_xlm_prophetnet.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +from ....tokenization_utils import PreTrainedTokenizer +from ....utils import logging + + +logger = logging.get_logger(__name__) + +SPIECE_UNDERLINE = "▁" + +VOCAB_FILES_NAMES = {"vocab_file": "prophetnet.tokenizer"} + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +class XLMProphetNetTokenizer(PreTrainedTokenizer): + """ + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `"[SEP]"`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `"[SEP]"`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="[SEP]", + eos_token="[SEP]", + sep_token="[SEP]", + unk_token="[UNK]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece" + " pip install sentencepiece" + ) + raise + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ---- + # fairseq | '' | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' + # spm | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' + + # put special tokens and [unused] tokens into the vocab + self.fairseq_tokens_to_ids = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "[MASK]": 4} + + for i in range(10): + tok = f"[unused{i}]" + self.fairseq_tokens_to_ids[tok] = 5 + i + + # The first "real" token "," has position 15 in the embedding vocab and position 3 in the spm vocab + self.fairseq_offset = 12 + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + + # TODO ArthurZ fairseq_ids_to_tokens should be removed + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + sep_token=sep_token, + unk_token=unk_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + try: + import sentencepiece as spm + except ImportError: + logger.warning( + "You need to install SentencePiece to use XLMRobertaTokenizer: https://github.com/google/sentencepiece" + " pip install sentencepiece" + ) + raise + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLMProphetNet + does not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + sep) * [0] + return len(token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + def vocab_size(self): + return len(self.sp_model) + self.fairseq_offset + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> str: + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A XLMProphetNet sequence has the following format: + + - single sequence: `X [SEP]` + - pair of sequences: `A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: list of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return token_ids_0 + [self.sep_token_id] + sep = [self.sep_token_id] + return token_ids_0 + sep + token_ids_1 + sep + + +__all__ = ["XLMProphetNetTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py b/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7425e37e0399c792155a88488045176fb3b5e7a5 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/depth_anything/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_depth_anything import * + from .modeling_depth_anything import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py b/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbe621a44310c2b217fe00fd2fba653de3489d9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/depth_anything/configuration_depth_anything.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DepthAnything model configuration""" + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class DepthAnythingConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DepthAnythingModel`]. It is used to instantiate a DepthAnything + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DepthAnything + [LiheYoung/depth-anything-small-hf](https://huggingface.co/LiheYoung/depth-anything-small-hf) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`] + API. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + patch_size (`int`, *optional*, defaults to 14): + The size of the patches to extract from the backbone features. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + reassemble_hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the reassemble layers. + reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`): + The up/downsampling factors of the reassemble layers. + neck_hidden_sizes (`List[str]`, *optional*, defaults to `[48, 96, 192, 384]`): + The hidden sizes to project to for the feature maps of the backbone. + fusion_hidden_size (`int`, *optional*, defaults to 64): + The number of channels before fusion. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the depth estimation head. + head_hidden_size (`int`, *optional*, defaults to 32): + The number of output channels in the second convolution of the depth estimation head. + depth_estimation_type (`str`, *optional*, defaults to `"relative"`): + The type of depth estimation to use. Can be one of `["relative", "metric"]`. + max_depth (`float`, *optional*): + The maximum depth to use for the "metric" depth estimation head. 20 should be used for indoor models + and 80 for outdoor models. For "relative" depth estimation, this value is ignored. + + Example: + + ```python + >>> from transformers import DepthAnythingConfig, DepthAnythingForDepthEstimation + + >>> # Initializing a DepthAnything small style configuration + >>> configuration = DepthAnythingConfig() + + >>> # Initializing a model from the DepthAnything small style configuration + >>> model = DepthAnythingForDepthEstimation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "depth_anything" + + def __init__( + self, + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + patch_size=14, + initializer_range=0.02, + reassemble_hidden_size=384, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_sizes=[48, 96, 192, 384], + fusion_hidden_size=64, + head_in_index=-1, + head_hidden_size=32, + depth_estimation_type="relative", + max_depth=None, + **kwargs, + ): + super().__init__(**kwargs) + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.") + backbone_config = CONFIG_MAPPING["dinov2"]( + image_size=518, + hidden_size=384, + num_attention_heads=6, + out_indices=[9, 10, 11, 12], + apply_layernorm=True, + reshape_hidden_states=False, + ) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.reassemble_hidden_size = reassemble_hidden_size + self.patch_size = patch_size + self.initializer_range = initializer_range + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.head_hidden_size = head_hidden_size + if depth_estimation_type not in ["relative", "metric"]: + raise ValueError("depth_estimation_type must be one of ['relative', 'metric']") + self.depth_estimation_type = depth_estimation_type + self.max_depth = max_depth if max_depth else 1 + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["DepthAnythingConfig"] diff --git a/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py b/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..75d7e74adde07d32393fe51673c5577d221bf148 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/detr/image_processing_detr.py @@ -0,0 +1,2048 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for DETR.""" + +import io +import pathlib +from collections import defaultdict +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + PaddingMode, + center_to_corners_format, + corners_to_center_format, + id_to_rgb, + pad, + rescale, + resize, + rgb_to_id, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + AnnotationFormat, + AnnotationType, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_annotations, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + is_flax_available, + is_jax_tensor, + is_scipy_available, + is_tf_available, + is_tf_tensor, + is_torch_available, + is_torch_tensor, + is_vision_available, + logging, +) +from ...utils.import_utils import requires + + +if is_torch_available(): + import torch + from torch import nn + + +if is_vision_available(): + import PIL + + +if is_scipy_available(): + import scipy.special + import scipy.stats + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +SUPPORTED_ANNOTATION_FORMATS = (AnnotationFormat.COCO_DETECTION, AnnotationFormat.COCO_PANOPTIC) + + +# From the original repo: https://github.com/facebookresearch/detr/blob/3af9fa878e73b6894ce3596450a8d9b89d918ca9/datasets/transforms.py#L76 +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`Tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = int(raw_size * height / width) + else: + oh = int(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = int(raw_size * width / height) + else: + ow = int(size * width / height) + + return (oh, ow) + + +def get_image_size_for_max_height_width( + input_image: np.ndarray, + max_height: int, + max_width: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image and the maximum allowed height and width. Keep aspect ratio. + Important, even if image_height < max_height and image_width < max_width, the image will be resized + to at least one of the edges be equal to max_height or max_width. + + For example: + - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50) + - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400) + + Args: + input_image (`np.ndarray`): + The image to resize. + max_height (`int`): + The maximum allowed height. + max_width (`int`): + The maximum allowed width. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + height, width = image_size + height_scale = max_height / height + width_scale = max_width / width + min_scale = min(height_scale, width_scale) + new_height = int(height * min_scale) + new_width = int(width * min_scale) + return new_height, new_width + + +def get_resize_output_image_size( + input_image: np.ndarray, + size: Union[int, Tuple[int, int], List[int]], + max_size: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. If the desired output size + is a tuple or list, the output image size is returned as is. If the desired output size is an integer, the output + image size is computed by keeping the aspect ratio of the input image size. + + Args: + input_image (`np.ndarray`): + The image to resize. + size (`int` or `Tuple[int, int]` or `List[int]`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred from the input image. + """ + image_size = get_image_size(input_image, input_data_format) + if isinstance(size, (list, tuple)): + return size + + return get_size_with_aspect_ratio(image_size, size, max_size) + + +def get_numpy_to_framework_fn(arr) -> Callable: + """ + Returns a function that converts a numpy array to the framework of the input array. + + Args: + arr (`np.ndarray`): The array to convert. + """ + if isinstance(arr, np.ndarray): + return np.array + if is_tf_available() and is_tf_tensor(arr): + import tensorflow as tf + + return tf.convert_to_tensor + if is_torch_available() and is_torch_tensor(arr): + import torch + + return torch.tensor + if is_flax_available() and is_jax_tensor(arr): + import jax.numpy as jnp + + return jnp.array + raise ValueError(f"Cannot convert arrays of type {type(arr)}") + + +def safe_squeeze(arr: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """ + Squeezes an array, but only if the axis specified has dim 1. + """ + if axis is None: + return arr.squeeze() + + try: + return arr.squeeze(axis=axis) + except ValueError: + return arr + + +def normalize_annotation(annotation: Dict, image_size: Tuple[int, int]) -> Dict: + image_height, image_width = image_size + norm_annotation = {} + for key, value in annotation.items(): + if key == "boxes": + boxes = value + boxes = corners_to_center_format(boxes) + boxes /= np.asarray([image_width, image_height, image_width, image_height], dtype=np.float32) + norm_annotation[key] = boxes + else: + norm_annotation[key] = value + return norm_annotation + + +# Copied from transformers.models.vilt.image_processing_vilt.max_across_indices +def max_across_indices(values: Iterable[Any]) -> List[Any]: + """ + Return the maximum value across all indices of an iterable of values. + """ + return [max(values_i) for values_i in zip(*values)] + + +# Copied from transformers.models.vilt.image_processing_vilt.get_max_height_width +def get_max_height_width( + images: List[np.ndarray], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> List[int]: + """ + Get the maximum height and width across all images in a batch. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(images[0]) + + if input_data_format == ChannelDimension.FIRST: + _, max_height, max_width = max_across_indices([img.shape for img in images]) + elif input_data_format == ChannelDimension.LAST: + max_height, max_width, _ = max_across_indices([img.shape for img in images]) + else: + raise ValueError(f"Invalid channel dimension format: {input_data_format}") + return (max_height, max_width) + + +# Copied from transformers.models.vilt.image_processing_vilt.make_pixel_mask +def make_pixel_mask( + image: np.ndarray, output_size: Tuple[int, int], input_data_format: Optional[Union[str, ChannelDimension]] = None +) -> np.ndarray: + """ + Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding. + + Args: + image (`np.ndarray`): + Image to make the pixel mask for. + output_size (`Tuple[int, int]`): + Output size of the mask. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + mask = np.zeros(output_size, dtype=np.int64) + mask[:input_height, :input_width] = 1 + return mask + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L33 +def convert_coco_poly_to_mask(segmentations, height: int, width: int) -> np.ndarray: + """ + Convert a COCO polygon annotation to a mask. + + Args: + segmentations (`List[List[float]]`): + List of polygons, each polygon represented by a list of x-y coordinates. + height (`int`): + Height of the mask. + width (`int`): + Width of the mask. + """ + try: + from pycocotools import mask as coco_mask + except ImportError: + raise ImportError("Pycocotools is not installed in your environment.") + + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = np.asarray(mask, dtype=np.uint8) + mask = np.any(mask, axis=2) + masks.append(mask) + if masks: + masks = np.stack(masks, axis=0) + else: + masks = np.zeros((0, height, width), dtype=np.uint8) + + return masks + + +# inspired by https://github.com/facebookresearch/detr/blob/master/datasets/coco.py#L50 +def prepare_coco_detection_annotation( + image, + target, + return_segmentation_masks: bool = False, + input_data_format: Optional[Union[ChannelDimension, str]] = None, +): + """ + Convert the target in COCO format into the format expected by DETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + + image_id = target["image_id"] + image_id = np.asarray([image_id], dtype=np.int64) + + # Get all COCO annotations for the given image. + annotations = target["annotations"] + annotations = [obj for obj in annotations if "iscrowd" not in obj or obj["iscrowd"] == 0] + + classes = [obj["category_id"] for obj in annotations] + classes = np.asarray(classes, dtype=np.int64) + + # for conversion to coco api + area = np.asarray([obj["area"] for obj in annotations], dtype=np.float32) + iscrowd = np.asarray([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in annotations], dtype=np.int64) + + boxes = [obj["bbox"] for obj in annotations] + # guard against no boxes via resizing + boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2] = boxes[:, 0::2].clip(min=0, max=image_width) + boxes[:, 1::2] = boxes[:, 1::2].clip(min=0, max=image_height) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + + new_target = {} + new_target["image_id"] = image_id + new_target["class_labels"] = classes[keep] + new_target["boxes"] = boxes[keep] + new_target["area"] = area[keep] + new_target["iscrowd"] = iscrowd[keep] + new_target["orig_size"] = np.asarray([int(image_height), int(image_width)], dtype=np.int64) + + if annotations and "keypoints" in annotations[0]: + keypoints = [obj["keypoints"] for obj in annotations] + # Converting the filtered keypoints list to a numpy array + keypoints = np.asarray(keypoints, dtype=np.float32) + # Apply the keep mask here to filter the relevant annotations + keypoints = keypoints[keep] + num_keypoints = keypoints.shape[0] + keypoints = keypoints.reshape((-1, 3)) if num_keypoints else keypoints + new_target["keypoints"] = keypoints + + if return_segmentation_masks: + segmentation_masks = [obj["segmentation"] for obj in annotations] + masks = convert_coco_poly_to_mask(segmentation_masks, image_height, image_width) + new_target["masks"] = masks[keep] + + return new_target + + +def masks_to_boxes(masks: np.ndarray) -> np.ndarray: + """ + Compute the bounding boxes around the provided panoptic segmentation masks. + + Args: + masks: masks in format `[number_masks, height, width]` where N is the number of masks + + Returns: + boxes: bounding boxes in format `[number_masks, 4]` in xyxy format + """ + if masks.size == 0: + return np.zeros((0, 4)) + + h, w = masks.shape[-2:] + y = np.arange(0, h, dtype=np.float32) + x = np.arange(0, w, dtype=np.float32) + # see https://github.com/pytorch/pytorch/issues/50276 + y, x = np.meshgrid(y, x, indexing="ij") + + x_mask = masks * np.expand_dims(x, axis=0) + x_max = x_mask.reshape(x_mask.shape[0], -1).max(-1) + x = np.ma.array(x_mask, mask=~(np.array(masks, dtype=bool))) + x_min = x.filled(fill_value=1e8) + x_min = x_min.reshape(x_min.shape[0], -1).min(-1) + + y_mask = masks * np.expand_dims(y, axis=0) + y_max = y_mask.reshape(x_mask.shape[0], -1).max(-1) + y = np.ma.array(y_mask, mask=~(np.array(masks, dtype=bool))) + y_min = y.filled(fill_value=1e8) + y_min = y_min.reshape(y_min.shape[0], -1).min(-1) + + return np.stack([x_min, y_min, x_max, y_max], 1) + + +def prepare_coco_panoptic_annotation( + image: np.ndarray, + target: Dict, + masks_path: Union[str, pathlib.Path], + return_masks: bool = True, + input_data_format: Union[ChannelDimension, str] = None, +) -> Dict: + """ + Prepare a coco panoptic annotation for DETR. + """ + image_height, image_width = get_image_size(image, channel_dim=input_data_format) + annotation_path = pathlib.Path(masks_path) / target["file_name"] + + new_target = {} + new_target["image_id"] = np.asarray([target["image_id"] if "image_id" in target else target["id"]], dtype=np.int64) + new_target["size"] = np.asarray([image_height, image_width], dtype=np.int64) + new_target["orig_size"] = np.asarray([image_height, image_width], dtype=np.int64) + + if "segments_info" in target: + masks = np.asarray(PIL.Image.open(annotation_path), dtype=np.uint32) + masks = rgb_to_id(masks) + + ids = np.array([segment_info["id"] for segment_info in target["segments_info"]]) + masks = masks == ids[:, None, None] + masks = masks.astype(np.uint8) + if return_masks: + new_target["masks"] = masks + new_target["boxes"] = masks_to_boxes(masks) + new_target["class_labels"] = np.array( + [segment_info["category_id"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["iscrowd"] = np.asarray( + [segment_info["iscrowd"] for segment_info in target["segments_info"]], dtype=np.int64 + ) + new_target["area"] = np.asarray( + [segment_info["area"] for segment_info in target["segments_info"]], dtype=np.float32 + ) + + return new_target + + +def get_segmentation_image( + masks: np.ndarray, input_size: Tuple, target_size: Tuple, stuff_equiv_classes, deduplicate=False +): + h, w = input_size + final_h, final_w = target_size + + m_id = scipy.special.softmax(masks.transpose(0, 1), -1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = np.zeros((h, w), dtype=np.int64) + else: + m_id = m_id.argmax(-1).reshape(h, w) + + if deduplicate: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + for eq_id in equiv: + m_id[m_id == eq_id] = equiv[0] + + seg_img = id_to_rgb(m_id) + seg_img = resize(seg_img, (final_w, final_h), resample=PILImageResampling.NEAREST) + return seg_img + + +def get_mask_area(seg_img: np.ndarray, target_size: Tuple[int, int], n_classes: int) -> np.ndarray: + final_h, final_w = target_size + np_seg_img = seg_img.astype(np.uint8) + np_seg_img = np_seg_img.reshape(final_h, final_w, 3) + m_id = rgb_to_id(np_seg_img) + area = [(m_id == i).sum() for i in range(n_classes)] + return area + + +def score_labels_from_class_probabilities(logits: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + probs = scipy.special.softmax(logits, axis=-1) + labels = probs.argmax(-1, keepdims=True) + scores = np.take_along_axis(probs, labels, axis=-1) + scores, labels = scores.squeeze(-1), labels.squeeze(-1) + return scores, labels + + +def post_process_panoptic_sample( + out_logits: np.ndarray, + masks: np.ndarray, + boxes: np.ndarray, + processed_size: Tuple[int, int], + target_size: Tuple[int, int], + is_thing_map: Dict, + threshold=0.85, +) -> Dict: + """ + Converts the output of [`DetrForSegmentation`] into panoptic segmentation predictions for a single sample. + + Args: + out_logits (`torch.Tensor`): + The logits for this sample. + masks (`torch.Tensor`): + The predicted segmentation masks for this sample. + boxes (`torch.Tensor`): + The prediced bounding boxes for this sample. The boxes are in the normalized format `(center_x, center_y, + width, height)` and values between `[0, 1]`, relative to the size the image (disregarding padding). + processed_size (`Tuple[int, int]`): + The processed size of the image `(height, width)`, as returned by the preprocessing step i.e. the size + after data augmentation but before batching. + target_size (`Tuple[int, int]`): + The target size of the image, `(height, width)` corresponding to the requested final size of the + prediction. + is_thing_map (`Dict`): + A dictionary mapping class indices to a boolean value indicating whether the class is a thing or not. + threshold (`float`, *optional*, defaults to 0.85): + The threshold used to binarize the segmentation masks. + """ + # we filter empty queries and detection below threshold + scores, labels = score_labels_from_class_probabilities(out_logits) + keep = (labels != out_logits.shape[-1] - 1) & (scores > threshold) + + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_boxes = center_to_corners_format(boxes[keep]) + + if len(cur_boxes) != len(cur_classes): + raise ValueError("Not as many boxes as there are classes") + + cur_masks = masks[keep] + cur_masks = resize(cur_masks[:, None], processed_size, resample=PILImageResampling.BILINEAR) + cur_masks = safe_squeeze(cur_masks, 1) + b, h, w = cur_masks.shape + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.reshape(b, -1) + stuff_equiv_classes = defaultdict(list) + for k, label in enumerate(cur_classes): + if not is_thing_map[label]: + stuff_equiv_classes[label].append(k) + + seg_img = get_segmentation_image(cur_masks, processed_size, target_size, stuff_equiv_classes, deduplicate=True) + area = get_mask_area(cur_masks, processed_size, n_classes=len(cur_scores)) + + # We filter out any mask that is too small + if cur_classes.size() > 0: + # We know filter empty masks as long as we find some + filtered_small = np.array([a <= 4 for a in area], dtype=bool) + while filtered_small.any(): + cur_masks = cur_masks[~filtered_small] + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + seg_img = get_segmentation_image(cur_masks, (h, w), target_size, stuff_equiv_classes, deduplicate=True) + area = get_mask_area(seg_img, target_size, n_classes=len(cur_scores)) + filtered_small = np.array([a <= 4 for a in area], dtype=bool) + else: + cur_classes = np.ones((1, 1), dtype=np.int64) + + segments_info = [ + {"id": i, "isthing": is_thing_map[cat], "category_id": int(cat), "area": a} + for i, (cat, a) in enumerate(zip(cur_classes, area)) + ] + del cur_classes + + with io.BytesIO() as out: + PIL.Image.fromarray(seg_img).save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + + return predictions + + +def resize_annotation( + annotation: Dict[str, Any], + orig_size: Tuple[int, int], + target_size: Tuple[int, int], + threshold: float = 0.5, + resample: PILImageResampling = PILImageResampling.NEAREST, +): + """ + Resizes an annotation to a target size. + + Args: + annotation (`Dict[str, Any]`): + The annotation dictionary. + orig_size (`Tuple[int, int]`): + The original size of the input image. + target_size (`Tuple[int, int]`): + The target size of the image, as returned by the preprocessing `resize` step. + threshold (`float`, *optional*, defaults to 0.5): + The threshold used to binarize the segmentation masks. + resample (`PILImageResampling`, defaults to `PILImageResampling.NEAREST`): + The resampling filter to use when resizing the masks. + """ + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(target_size, orig_size)) + ratio_height, ratio_width = ratios + + new_annotation = {} + new_annotation["size"] = target_size + + for key, value in annotation.items(): + if key == "boxes": + boxes = value + scaled_boxes = boxes * np.asarray([ratio_width, ratio_height, ratio_width, ratio_height], dtype=np.float32) + new_annotation["boxes"] = scaled_boxes + elif key == "area": + area = value + scaled_area = area * (ratio_width * ratio_height) + new_annotation["area"] = scaled_area + elif key == "masks": + masks = value[:, None] + masks = np.array([resize(mask, target_size, resample=resample) for mask in masks]) + masks = masks.astype(np.float32) + masks = masks[:, 0] > threshold + new_annotation["masks"] = masks + elif key == "size": + new_annotation["size"] = target_size + else: + new_annotation[key] = value + + return new_annotation + + +# TODO - (Amy) make compatible with other frameworks +def binary_mask_to_rle(mask): + """ + Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + mask (`torch.Tensor` or `numpy.array`): + A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target + segment_id or class_id. + Returns: + `List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE + format. + """ + if is_torch_tensor(mask): + mask = mask.numpy() + + pixels = mask.flatten() + pixels = np.concatenate([[0], pixels, [0]]) + runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 + runs[1::2] -= runs[::2] + return list(runs) + + +# TODO - (Amy) make compatible with other frameworks +def convert_segmentation_to_rle(segmentation): + """ + Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format. + + Args: + segmentation (`torch.Tensor` or `numpy.array`): + A segmentation map of shape `(height, width)` where each value denotes a segment or class id. + Returns: + `List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id. + """ + segment_ids = torch.unique(segmentation) + + run_length_encodings = [] + for idx in segment_ids: + mask = torch.where(segmentation == idx, 1, 0) + rle = binary_mask_to_rle(mask) + run_length_encodings.append(rle) + + return run_length_encodings + + +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`torch.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`torch.Tensor`): + A tensor of shape `(num_queries)`. + labels (`torch.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_area = (mask_probs[k] >= mask_threshold).sum() + mask_exists = mask_k_area > 0 and original_area > 0 + + # Eliminate disconnected tiny segments + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, mask_k + + +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_size: Tuple[int, int] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device) + segments: List[Dict] = [] + + if target_size is not None: + mask_probs = nn.functional.interpolate( + mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False + )[0] + + current_segment_id = 0 + + # Weigh each mask by its prediction score + mask_probs *= pred_scores.view(-1, 1, 1) + mask_labels = mask_probs.argmax(0) # [height, width] + + # Keep track of instances of each class + stuff_memory_list: Dict[str, int] = {} + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + should_fuse = pred_class in label_ids_to_fuse + + # Check if mask exists and large enough to be a segment + mask_exists, mask_k = check_segment_validity( + mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold + ) + + if mask_exists: + if pred_class in stuff_memory_list: + current_segment_id = stuff_memory_list[pred_class] + else: + current_segment_id += 1 + + # Add current object segment to final segmentation map + segmentation[mask_k] = current_segment_id + segment_score = round(pred_scores[k].item(), 6) + segments.append( + { + "id": current_segment_id, + "label_id": pred_class, + "was_fused": should_fuse, + "score": segment_score, + } + ) + if should_fuse: + stuff_memory_list[pred_class] = current_segment_id + + return segmentation, segments + + +@requires(backends=("vision",)) +class DetrImageProcessor(BaseImageProcessor): + r""" + Constructs a Detr image processor. + + Args: + format (`str`, *optional*, defaults to `"coco_detection"`): + Data format of the annotations. One of "coco_detection" or "coco_panoptic". + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's `(height, width)` dimensions to the specified `size`. Can be + overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 800, "longest_edge": 1333}`): + Size of the image's `(height, width)` dimensions after resizing. Can be overridden by the `size` parameter + in the `preprocess` method. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Controls whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to True): + Controls whether to normalize the image. Can be overridden by the `do_normalize` parameter in the + `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean values to use when normalizing the image. Can be a single value or a list of values, one for each + channel. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation values to use when normalizing the image. Can be a single value or a list of values, one + for each channel. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_annotations (`bool`, *optional*, defaults to `True`): + Controls whether to convert the annotations to the format expected by the DETR model. Converts the + bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`. + Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess` + method. If `True`, padding will be applied to the bottom and right of the image with zeros. + If `pad_size` is provided, the image will be padded to the specified dimensions. + Otherwise, the image will be padded to the maximum height and width of the batch. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + + model_input_names = ["pixel_values", "pixel_mask"] + + def __init__( + self, + format: Union[str, AnnotationFormat] = AnnotationFormat.COCO_DETECTION, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Union[float, List[float]] = None, + image_std: Union[float, List[float]] = None, + do_convert_annotations: Optional[bool] = None, + do_pad: bool = True, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> None: + if "pad_and_return_pixel_mask" in kwargs: + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None if size is None else 1333 + + size = size if size is not None else {"shortest_edge": 800, "longest_edge": 1333} + size = get_size_dict(size, max_size=max_size, default_to_square=False) + + # Backwards compatibility + if do_convert_annotations is None: + do_convert_annotations = do_normalize + + super().__init__(**kwargs) + self.format = format + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.do_convert_annotations = do_convert_annotations + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_pad = do_pad + self.pad_size = pad_size + self._valid_processor_keys = [ + "images", + "annotations", + "return_segmentation_masks", + "masks_path", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "do_convert_annotations", + "image_mean", + "image_std", + "do_pad", + "pad_size", + "format", + "return_tensors", + "data_format", + "input_data_format", + ] + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure parameters are updated if image processor is + created using from_dict and kwargs e.g. `DetrImageProcessor.from_pretrained(checkpoint, size=600, + max_size=800)` + """ + image_processor_dict = image_processor_dict.copy() + if "max_size" in kwargs: + image_processor_dict["max_size"] = kwargs.pop("max_size") + if "pad_and_return_pixel_mask" in kwargs: + image_processor_dict["pad_and_return_pixel_mask"] = kwargs.pop("pad_and_return_pixel_mask") + return super().from_dict(image_processor_dict, **kwargs) + + def prepare_annotation( + self, + image: np.ndarray, + target: Dict, + format: Optional[AnnotationFormat] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> Dict: + """ + Prepare an annotation for feeding into DETR model. + """ + format = format if format is not None else self.format + + if format == AnnotationFormat.COCO_DETECTION: + return_segmentation_masks = False if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_detection_annotation( + image, target, return_segmentation_masks, input_data_format=input_data_format + ) + elif format == AnnotationFormat.COCO_PANOPTIC: + return_segmentation_masks = True if return_segmentation_masks is None else return_segmentation_masks + target = prepare_coco_panoptic_annotation( + image, + target, + masks_path=masks_path, + return_masks=return_segmentation_masks, + input_data_format=input_data_format, + ) + else: + raise ValueError(f"Format {format} is not supported.") + return target + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to the given size. Size can be `min_size` (scalar) or `(height, width)` tuple. If size is an + int, smaller edge of the image will be matched to this number. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + Resampling filter to use if resizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` parameter is deprecated and will be removed in v4.26. " + "Please specify in `size['longest_edge'] instead`.", + ) + max_size = kwargs.pop("max_size") + else: + max_size = None + size = get_size_dict(size, max_size=max_size, default_to_square=False) + if "shortest_edge" in size and "longest_edge" in size: + new_size = get_resize_output_image_size( + image, size["shortest_edge"], size["longest_edge"], input_data_format=input_data_format + ) + elif "max_height" in size and "max_width" in size: + new_size = get_image_size_for_max_height_width( + image, size["max_height"], size["max_width"], input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + new_size = (size["height"], size["width"]) + else: + raise ValueError( + "Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got" + f" {size.keys()}." + ) + image = resize( + image, + size=new_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return image + + def resize_annotation( + self, + annotation, + orig_size, + size, + resample: PILImageResampling = PILImageResampling.NEAREST, + ) -> Dict: + """ + Resize the annotation to match the resized image. If size is an int, smaller edge of the mask will be matched + to this number. + """ + return resize_annotation(annotation, orig_size=orig_size, target_size=size, resample=resample) + + # TODO (Amy) - update to use `rescale_factor` instead of `scale` + def rescale( + self, + image: np.ndarray, + rescale_factor: float, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Rescale the image by the given factor. image = image * rescale_factor. + + Args: + image (`np.ndarray`): + Image to rescale. + rescale_factor (`float`): + The value to use for rescaling. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. If unset, is inferred from the input image. Can be + one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + """ + return rescale(image, rescale_factor, data_format=data_format, input_data_format=input_data_format) + + def normalize_annotation(self, annotation: Dict, image_size: Tuple[int, int]) -> Dict: + """ + Normalize the boxes in the annotation from `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]` to + `[center_x, center_y, width, height]` format and from absolute to relative pixel values. + """ + return normalize_annotation(annotation, image_size=image_size) + + def _update_annotation_for_padded_image( + self, + annotation: Dict, + input_image_size: Tuple[int, int], + output_image_size: Tuple[int, int], + padding, + update_bboxes, + ) -> Dict: + """ + Update the annotation for a padded image. + """ + new_annotation = {} + new_annotation["size"] = output_image_size + + for key, value in annotation.items(): + if key == "masks": + masks = value + masks = pad( + masks, + padding, + mode=PaddingMode.CONSTANT, + constant_values=0, + input_data_format=ChannelDimension.FIRST, + ) + masks = safe_squeeze(masks, 1) + new_annotation["masks"] = masks + elif key == "boxes" and update_bboxes: + boxes = value + boxes *= np.asarray( + [ + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + input_image_size[1] / output_image_size[1], + input_image_size[0] / output_image_size[0], + ] + ) + new_annotation["boxes"] = boxes + elif key == "size": + new_annotation["size"] = output_image_size + else: + new_annotation[key] = value + return new_annotation + + def _pad_image( + self, + image: np.ndarray, + output_size: Tuple[int, int], + annotation: Optional[Dict[str, Any]] = None, + constant_values: Union[float, Iterable[float]] = 0, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + ) -> np.ndarray: + """ + Pad an image with zeros to the given size. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = output_size + + pad_bottom = output_height - input_height + pad_right = output_width - input_width + padding = ((0, pad_bottom), (0, pad_right)) + padded_image = pad( + image, + padding, + mode=PaddingMode.CONSTANT, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + ) + if annotation is not None: + annotation = self._update_annotation_for_padded_image( + annotation, (input_height, input_width), (output_height, output_width), padding, update_bboxes + ) + return padded_image, annotation + + def pad( + self, + images: List[np.ndarray], + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + constant_values: Union[float, Iterable[float]] = 0, + return_pixel_mask: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + update_bboxes: bool = True, + pad_size: Optional[Dict[str, int]] = None, + ) -> BatchFeature: + """ + Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width + in the batch and optionally returns their corresponding pixel mask. + + Args: + images (List[`np.ndarray`]): + Images to pad. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + Annotations to transform according to the padding that is applied to the images. + constant_values (`float` or `Iterable[float]`, *optional*): + The value to use for the padding if `mode` is `"constant"`. + return_pixel_mask (`bool`, *optional*, defaults to `True`): + Whether to return a pixel mask. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + update_bboxes (`bool`, *optional*, defaults to `True`): + Whether to update the bounding boxes in the annotations to match the padded images. If the + bounding boxes have not been converted to relative coordinates and `(centre_x, centre_y, width, height)` + format, the bounding boxes will not be updated. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + pad_size = pad_size if pad_size is not None else self.pad_size + if pad_size is not None: + padded_size = (pad_size["height"], pad_size["width"]) + else: + padded_size = get_max_height_width(images, input_data_format=input_data_format) + + annotation_list = annotations if annotations is not None else [None] * len(images) + padded_images = [] + padded_annotations = [] + for image, annotation in zip(images, annotation_list): + padded_image, padded_annotation = self._pad_image( + image, + padded_size, + annotation, + constant_values=constant_values, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=update_bboxes, + ) + padded_images.append(padded_image) + padded_annotations.append(padded_annotation) + + data = {"pixel_values": padded_images} + + if return_pixel_mask: + masks = [ + make_pixel_mask(image=image, output_size=padded_size, input_data_format=input_data_format) + for image in images + ] + data["pixel_mask"] = masks + + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in padded_annotations + ] + + return encoded_inputs + + def preprocess( + self, + images: ImageInput, + annotations: Optional[Union[AnnotationType, List[AnnotationType]]] = None, + return_segmentation_masks: Optional[bool] = None, + masks_path: Optional[Union[str, pathlib.Path]] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample=None, # PILImageResampling + do_rescale: Optional[bool] = None, + rescale_factor: Optional[Union[int, float]] = None, + do_normalize: Optional[bool] = None, + do_convert_annotations: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + format: Optional[Union[str, AnnotationFormat]] = None, + return_tensors: Optional[Union[TensorType, str]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_size: Optional[Dict[str, int]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image or a batch of images so that it can be used by the model. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects a single or batch of images with pixel values ranging + from 0 to 255. If passing in images with pixel values between 0 and 1, set `do_rescale=False`. + annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): + List of annotations associated with the image or batch of images. If annotation is for object + detection, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a + dictionary. An image can have no annotations, in which case the list should be empty. + If annotation is for segmentation, the annotations should be a dictionary with the following keys: + - "image_id" (`int`): The image id. + - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. + An image can have no segments, in which case the list should be empty. + - "file_name" (`str`): The file name of the image. + return_segmentation_masks (`bool`, *optional*, defaults to self.return_segmentation_masks): + Whether to return segmentation masks. + masks_path (`str` or `pathlib.Path`, *optional*): + Path to the directory containing the segmentation masks. + do_resize (`bool`, *optional*, defaults to self.do_resize): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to self.size): + Size of the image's `(height, width)` dimensions after resizing. Available options are: + - `{"height": int, "width": int}`: The image will be resized to the exact size `(height, width)`. + Do NOT keep the aspect ratio. + - `{"shortest_edge": int, "longest_edge": int}`: The image will be resized to a maximum size respecting + the aspect ratio and keeping the shortest edge less or equal to `shortest_edge` and the longest edge + less or equal to `longest_edge`. + - `{"max_height": int, "max_width": int}`: The image will be resized to the maximum size respecting the + aspect ratio and keeping the height less or equal to `max_height` and the width less or equal to + `max_width`. + resample (`PILImageResampling`, *optional*, defaults to self.resample): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to self.do_rescale): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to self.rescale_factor): + Rescale factor to use when rescaling the image. + do_normalize (`bool`, *optional*, defaults to self.do_normalize): + Whether to normalize the image. + do_convert_annotations (`bool`, *optional*, defaults to self.do_convert_annotations): + Whether to convert the annotations to the format expected by the model. Converts the bounding + boxes from the format `(top_left_x, top_left_y, width, height)` to `(center_x, center_y, width, height)` + and in relative coordinates. + image_mean (`float` or `List[float]`, *optional*, defaults to self.image_mean): + Mean to use when normalizing the image. + image_std (`float` or `List[float]`, *optional*, defaults to self.image_std): + Standard deviation to use when normalizing the image. + do_pad (`bool`, *optional*, defaults to self.do_pad): + Whether to pad the image. If `True`, padding will be applied to the bottom and right of + the image with zeros. If `pad_size` is provided, the image will be padded to the specified + dimensions. Otherwise, the image will be padded to the maximum height and width of the batch. + format (`str` or `AnnotationFormat`, *optional*, defaults to self.format): + Format of the annotations. + return_tensors (`str` or `TensorType`, *optional*, defaults to self.return_tensors): + Type of tensors to return. If `None`, will return the list of images. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + pad_size (`Dict[str, int]`, *optional*): + The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size + provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest + height and width in the batch. + """ + if "pad_and_return_pixel_mask" in kwargs: + logger.warning_once( + "The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version, " + "use `do_pad` instead." + ) + do_pad = kwargs.pop("pad_and_return_pixel_mask") + + if "max_size" in kwargs: + logger.warning_once( + "The `max_size` argument is deprecated and will be removed in a future version, use" + " `size['longest_edge']` instead." + ) + size = kwargs.pop("max_size") + + do_resize = self.do_resize if do_resize is None else do_resize + size = self.size if size is None else size + size = get_size_dict(size=size, default_to_square=False) + resample = self.resample if resample is None else resample + do_rescale = self.do_rescale if do_rescale is None else do_rescale + rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = self.do_normalize if do_normalize is None else do_normalize + image_mean = self.image_mean if image_mean is None else image_mean + image_std = self.image_std if image_std is None else image_std + do_convert_annotations = ( + self.do_convert_annotations if do_convert_annotations is None else do_convert_annotations + ) + do_pad = self.do_pad if do_pad is None else do_pad + pad_size = self.pad_size if pad_size is None else pad_size + format = self.format if format is None else format + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + # Here, the pad() method pads to the maximum of (width, height). It does not need to be validated. + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if annotations is not None and isinstance(annotations, dict): + annotations = [annotations] + + if annotations is not None and len(images) != len(annotations): + raise ValueError( + f"The number of images ({len(images)}) and annotations ({len(annotations)}) do not match." + ) + + format = AnnotationFormat(format) + if annotations is not None: + validate_annotations(format, SUPPORTED_ANNOTATION_FORMATS, annotations) + + if ( + masks_path is not None + and format == AnnotationFormat.COCO_PANOPTIC + and not isinstance(masks_path, (pathlib.Path, str)) + ): + raise ValueError( + "The path to the directory containing the mask PNG files should be provided as a" + f" `pathlib.Path` or string object, but is {type(masks_path)} instead." + ) + + # All transformations expect numpy arrays + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + # prepare (COCO annotations as a list of Dict -> DETR target as a single Dict per image) + if annotations is not None: + prepared_images = [] + prepared_annotations = [] + for image, target in zip(images, annotations): + target = self.prepare_annotation( + image, + target, + format, + return_segmentation_masks=return_segmentation_masks, + masks_path=masks_path, + input_data_format=input_data_format, + ) + prepared_images.append(image) + prepared_annotations.append(target) + images = prepared_images + annotations = prepared_annotations + del prepared_images, prepared_annotations + + # transformations + if do_resize: + if annotations is not None: + resized_images, resized_annotations = [], [] + for image, target in zip(images, annotations): + orig_size = get_image_size(image, input_data_format) + resized_image = self.resize( + image, size=size, resample=resample, input_data_format=input_data_format + ) + resized_annotation = self.resize_annotation( + target, orig_size, get_image_size(resized_image, input_data_format) + ) + resized_images.append(resized_image) + resized_annotations.append(resized_annotation) + images = resized_images + annotations = resized_annotations + del resized_images, resized_annotations + else: + images = [ + self.resize(image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [self.rescale(image, rescale_factor, input_data_format=input_data_format) for image in images] + + if do_normalize: + images = [ + self.normalize(image, image_mean, image_std, input_data_format=input_data_format) for image in images + ] + + if do_convert_annotations and annotations is not None: + annotations = [ + self.normalize_annotation(annotation, get_image_size(image, input_data_format)) + for annotation, image in zip(annotations, images) + ] + + if do_pad: + # Pads images and returns their mask: {'pixel_values': ..., 'pixel_mask': ...} + encoded_inputs = self.pad( + images, + annotations=annotations, + return_pixel_mask=True, + data_format=data_format, + input_data_format=input_data_format, + update_bboxes=do_convert_annotations, + return_tensors=return_tensors, + pad_size=pad_size, + ) + else: + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in images + ] + encoded_inputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + if annotations is not None: + encoded_inputs["labels"] = [ + BatchFeature(annotation, tensor_type=return_tensors) for annotation in annotations + ] + + return encoded_inputs + + # POSTPROCESSING METHODS - TODO: add support for other frameworks + # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258 + def post_process(self, outputs, target_sizes): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (height, width) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). For visualization, this should be the image size + after data augment, but before padding. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_object_detection` instead, with `threshold=0.` for equivalent results.", + ) + + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if len(out_logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + return results + + def post_process_segmentation(self, outputs, target_sizes, threshold=0.9, mask_threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into image segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. + threshold (`float`, *optional*, defaults to 0.9): + Threshold to use to filter out queries. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, and masks for an image + in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process_segmentation` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_semantic_segmentation`.", + ) + out_logits, raw_masks = outputs.logits, outputs.pred_masks + empty_label = out_logits.shape[-1] - 1 + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.tolist()) + + for cur_logits, cur_masks, size in zip(out_logits, raw_masks, target_sizes): + # we filter empty queries and detection below threshold + cur_scores, cur_labels = cur_logits.softmax(-1).max(-1) + keep = cur_labels.ne(empty_label) & (cur_scores > threshold) + cur_scores = cur_scores[keep] + cur_labels = cur_labels[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_masks = (cur_masks.sigmoid() > mask_threshold) * 1 + + predictions = {"scores": cur_scores, "labels": cur_labels, "masks": cur_masks} + preds.append(predictions) + return preds + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218 + def post_process_instance(self, results, outputs, orig_target_sizes, max_target_sizes, threshold=0.5): + """ + Converts the output of [`DetrForSegmentation`] into actual instance segmentation predictions. Only supports + PyTorch. + + Args: + results (`List[Dict]`): + Results list obtained by [`~DetrImageProcessor.post_process`], to which "masks" results will be added. + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + orig_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the size (h, w) of each image of the batch. For evaluation, this must be the original + image size (before any data augmentation). + max_target_sizes (`torch.Tensor` of shape `(batch_size, 2)`): + Tensor containing the maximum size (h, w) of each image of the batch. For evaluation, this must be the + original image size (before any data augmentation). + threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels, boxes and masks for an + image in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process_instance` is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_instance_segmentation`.", + ) + + if len(orig_target_sizes) != len(max_target_sizes): + raise ValueError("Make sure to pass in as many orig_target_sizes as max_target_sizes") + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs.pred_masks.squeeze(2) + outputs_masks = nn.functional.interpolate( + outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False + ) + outputs_masks = (outputs_masks.sigmoid() > threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = nn.functional.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241 + def post_process_panoptic(self, outputs, processed_sizes, target_sizes=None, is_thing_map=None, threshold=0.85): + """ + Converts the output of [`DetrForSegmentation`] into actual panoptic predictions. Only supports PyTorch. + + Args: + outputs ([`DetrSegmentationOutput`]): + Raw outputs of the model. + processed_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`): + Torch Tensor (or list) containing the size (h, w) of each image of the batch, i.e. the size after data + augmentation but before batching. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): + Torch Tensor (or list) corresponding to the requested final size `(height, width)` of each prediction. + If left to None, it will default to the `processed_sizes`. + is_thing_map (`torch.Tensor` of shape `(batch_size, 2)`, *optional*): + Dictionary mapping class indices to either True or False, depending on whether or not they are a thing. + If not set, defaults to the `is_thing_map` of COCO panoptic. + threshold (`float`, *optional*, defaults to 0.85): + Threshold to use to filter out queries. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing a PNG string and segments_info values for + an image in the batch as predicted by the model. + """ + logger.warning_once( + "`post_process_panoptic is deprecated and will be removed in v5 of Transformers, please use" + " `post_process_panoptic_segmentation`.", + ) + if target_sizes is None: + target_sizes = processed_sizes + if len(processed_sizes) != len(target_sizes): + raise ValueError("Make sure to pass in as many processed_sizes as target_sizes") + + if is_thing_map is None: + # default to is_thing_map of COCO panoptic + is_thing_map = {i: i <= 90 for i in range(201)} + + out_logits, raw_masks, raw_boxes = outputs.logits, outputs.pred_masks, outputs.pred_boxes + if not len(out_logits) == len(raw_masks) == len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits and masks" + ) + empty_label = out_logits.shape[-1] - 1 + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + cur_scores, cur_labels = cur_logits.softmax(-1).max(-1) + keep = cur_labels.ne(empty_label) & (cur_scores > threshold) + cur_scores = cur_scores[keep] + cur_labels = cur_labels[keep] + cur_masks = cur_masks[keep] + cur_masks = nn.functional.interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) + cur_boxes = center_to_corners_format(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + if len(cur_boxes) != len(cur_labels): + raise ValueError("Not as many boxes as there are classes") + + # It may be that we have several predicted masks for the same stuff class. + # In the following, we track the list of masks ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_labels): + if not is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = PIL.Image.fromarray(id_to_rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=PILImageResampling.NEAREST) + + np_seg_img = torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())) + np_seg_img = np_seg_img.view(final_h, final_w, 3) + np_seg_img = np_seg_img.numpy() + + m_id = torch.from_numpy(rgb_to_id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_labels.numel() > 0: + # We know filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_labels)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_labels = cur_labels[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_labels = torch.ones(1, dtype=torch.long, device=cur_labels.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_labels[i].item() + segments_info.append({"id": i, "isthing": is_thing_map[cat], "category_id": cat, "area": a}) + del cur_labels + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/detr.py#L258 + def post_process_object_detection( + self, outputs, threshold: float = 0.5, target_sizes: Union[TensorType, List[Tuple]] = None + ): + """ + Converts the raw output of [`DetrForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, + bottom_right_x, bottom_right_y) format. Only supports PyTorch. + + Args: + outputs ([`DetrObjectDetectionOutput`]): + Raw outputs of the model. + threshold (`float`, *optional*): + Score threshold to keep object detection predictions. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. If unset, predictions will not be resized. + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image + in the batch as predicted by the model. + """ + out_logits, out_bbox = outputs.logits, outputs.pred_boxes + + if target_sizes is not None: + if len(out_logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + prob = nn.functional.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # Convert to [x0, y0, x1, y1] format + boxes = center_to_corners_format(out_bbox) + + # Convert from relative [0, 1] to absolute [0, height] coordinates + if target_sizes is not None: + if isinstance(target_sizes, List): + img_h = torch.Tensor([i[0] for i in target_sizes]) + img_w = torch.Tensor([i[1] for i in target_sizes]) + else: + img_h, img_w = target_sizes.unbind(1) + + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) + boxes = boxes * scale_fct[:, None, :] + + results = [] + for s, l, b in zip(scores, labels, boxes): + score = s[s > threshold] + label = l[s > threshold] + box = b[s > threshold] + results.append({"scores": score, "labels": label, "boxes": box}) + + return results + + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple[int, int]] = None): + """ + Converts the output of [`DetrForSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple[int, int]]`, *optional*): + A list of tuples (`Tuple[int, int]`) containing the target size (height, width) of each image in the + batch. If unset, predictions will not be resized. + Returns: + `List[torch.Tensor]`: + A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width) + corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each + `torch.Tensor` correspond to a semantic class id. + """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Semantic segmentation logits of shape (batch_size, num_classes, height, width) + segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + batch_size = class_queries_logits.shape[0] + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + semantic_segmentation = [] + for idx in range(batch_size): + resized_logits = nn.functional.interpolate( + segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = segmentation.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L218 + def post_process_instance_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_sizes: Optional[List[Tuple[int, int]]] = None, + return_coco_annotation: Optional[bool] = False, + ) -> List[Dict]: + """ + Converts the output of [`DetrForSegmentation`] into instance segmentation predictions. Only supports PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + Raw outputs of the model. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction. If unset, predictions will not be resized. + return_coco_annotation (`bool`, *optional*): + Defaults to `False`. If set to `True`, segmentation maps are returned in COCO run-length encoding (RLE) + format. + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- A tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `List[List]` run-length encoding (RLE) of the segmentation map if return_coco_annotation is set to + `True`. Set to `None` if no mask if found above `threshold`. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- An integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=[], + target_size=target_size, + ) + + # Return segmentation map in run-length encoding (RLE) format + if return_coco_annotation: + segmentation = convert_segmentation_to_rle(segmentation) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + # inspired by https://github.com/facebookresearch/detr/blob/master/models/segmentation.py#L241 + def post_process_panoptic_segmentation( + self, + outputs, + threshold: float = 0.5, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + label_ids_to_fuse: Optional[Set[int]] = None, + target_sizes: Optional[List[Tuple[int, int]]] = None, + ) -> List[Dict]: + """ + Converts the output of [`DetrForSegmentation`] into image panoptic segmentation predictions. Only supports + PyTorch. + + Args: + outputs ([`DetrForSegmentation`]): + The outputs from [`DetrForSegmentation`]. + threshold (`float`, *optional*, defaults to 0.5): + The probability score threshold to keep predicted instance masks. + mask_threshold (`float`, *optional*, defaults to 0.5): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8): + The overlap mask area threshold to merge or discard small disconnected parts within each binary + instance mask. + label_ids_to_fuse (`Set[int]`, *optional*): + The labels in this state will have all their instances be fused together. For instance we could say + there can only be one sky in an image, but several persons, so the label ID for sky would be in that + set, but not the one for person. + target_sizes (`List[Tuple]`, *optional*): + List of length (batch_size), where each list item (`Tuple[int, int]]`) corresponds to the requested + final size (height, width) of each prediction in batch. If unset, predictions will not be resized. + Returns: + `List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys: + - **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id` or + `None` if no mask if found above `threshold`. If `target_sizes` is specified, segmentation is resized to + the corresponding `target_sizes` entry. + - **segments_info** -- A dictionary that contains additional information on each segment. + - **id** -- an integer representing the `segment_id`. + - **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`. + - **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise. + Multiple instances of the same class / label were fused and assigned a single `segment_id`. + - **score** -- Prediction score of segment with `segment_id`. + """ + + if label_ids_to_fuse is None: + logger.warning_once("`label_ids_to_fuse` unset. No instance will be fused.") + label_ids_to_fuse = set() + + class_queries_logits = outputs.logits # [batch_size, num_queries, num_classes+1] + masks_queries_logits = outputs.pred_masks # [batch_size, num_queries, height, width] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + # Predicted label and score of each query (batch_size, num_queries) + pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1) + + # Loop over items in batch size + results: List[Dict[str, TensorType]] = [] + + for i in range(batch_size): + mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects( + mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels + ) + + # No mask found + if mask_probs_item.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:] + segmentation = torch.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + # Get segmentation map and segment information of batch item + target_size = target_sizes[i] if target_sizes is not None else None + segmentation, segments = compute_segments( + mask_probs=mask_probs_item, + pred_scores=pred_scores_item, + pred_labels=pred_labels_item, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + label_ids_to_fuse=label_ids_to_fuse, + target_size=target_size, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["DetrImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py b/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d736cb4f99d912ea8493e35f886be6f4702a21 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/diffllama/modeling_diffllama.py @@ -0,0 +1,1389 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/diffllama/modular_diffllama.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_diffllama.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on Llama implementations in this library and Microsoft's +# Differential Transformer implementations. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, + _flash_attention_forward, + flash_attn_supports_top_left_mask, +) +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_diffllama import DiffLlamaConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" +_CONFIG_FOR_DOC = "DiffLlamaConfig" + + +class DiffLlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def lambda_init_fn(layer_idx): + return 0.8 - 0.6 * math.exp(-0.3 * layer_idx) + + +class DiffLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + # under this are not used + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.lambda_init = lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, target_len, _ = hidden_states.size() + q_len = target_len + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = torch.matmul(attn_weights, value_states) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DiffLlamaFlashAttention2(DiffLlamaAttention): + """ + DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DiffLlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + value_states1, value_states2 = torch.chunk(value_states, 2, dim=2) + value_states1 = value_states1.repeat(1, 1, 2, 1) + value_states2 = value_states2.repeat(1, 1, 2, 1) + + attn_output1 = _flash_attention_forward( + query_states, + key_states, + value_states1, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output2 = _flash_attention_forward( + query_states, + key_states, + value_states2, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DiffLlamaSdpaAttention(DiffLlamaAttention): + """ + DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from DiffLlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +@use_kernel_forward_from_hub("RMSNorm") +class DiffLlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DiffLlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +DIFFLLAMA_ATTENTION_CLASSES = { + "eager": DiffLlamaAttention, + "flash_attention_2": DiffLlamaFlashAttention2, + "sdpa": DiffLlamaSdpaAttention, +} + + +class DiffLlamaDecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DiffLlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = DiffLlamaMLP(config) + self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +DIFFLLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DiffLlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.", + DIFFLLAMA_START_DOCSTRING, +) +class DiffLlamaPreTrainedModel(PreTrainedModel): + config_class = DiffLlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DiffLlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = False + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) + elif isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + + +class DiffLlamaRotaryEmbedding(nn.Module): + def __init__(self, config: DiffLlamaConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +DIFFLLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length) or `BlockMask`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If the model is configured to use flex_attention, it will attempt to convert the mask Tensor into a BlockMask, + but you can also pass a `BlockMask` object directly here. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare DiffLlama Model outputting raw hidden-states without any specific head on top.", + DIFFLLAMA_START_DOCSTRING, +) +class DiffLlamaModel(DiffLlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DiffLlamaDecoderLayer`] + + Args: + config: DiffLlamaConfig + """ + + def __init__(self, config: DiffLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DiffLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DiffLlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DiffLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DiffLlamaForCausalLM + + >>> model = DiffLlamaForCausalLM.from_pretrained("google/diffllama-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/diffllama-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DiffLlama Model transformer with a sequence classification head on top (linear layer). + + [`DiffLlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DIFFLLAMA_START_DOCSTRING, +) +class DiffLlamaForSequenceClassification(DiffLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DiffLlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> SequenceClassifierOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + transformer_outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + hidden_states = transformer_outputs.last_hidden_state + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + last_non_pad_token = -1 + elif input_ids is not None: + # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id + non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) + token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) + last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) + else: + last_non_pad_token = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The DiffLlama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DIFFLLAMA_START_DOCSTRING, +) +class DiffLlamaForQuestionAnswering(DiffLlamaPreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = DiffLlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> QuestionAnsweringModelOutput: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + + outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = outputs.last_hidden_state + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The DiffLlama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + DIFFLLAMA_START_DOCSTRING, +) +class DiffLlamaForTokenClassification(DiffLlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DiffLlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(DIFFLLAMA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TokenClassifierOutput: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs: BaseModelOutputWithPast = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + sequence_output = outputs.last_hidden_state + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "DiffLlamaPreTrainedModel", + "DiffLlamaModel", + "DiffLlamaForCausalLM", + "DiffLlamaForSequenceClassification", + "DiffLlamaForQuestionAnswering", + "DiffLlamaForTokenClassification", +] diff --git a/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py b/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bc2d2c5ac1fbf98e7ffe9607d2dd8c4e5ae265 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/diffllama/modular_diffllama.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on Llama implementations in this library and Microsoft's +# Differential Transformer implementations. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache, StaticCache +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...utils import logging +from ..gemma.modeling_gemma import GemmaForCausalLM +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaPreTrainedModel, + apply_rotary_pos_emb, + repeat_kv, +) +from ..mistral.modeling_mistral import MistralMLP +from .configuration_diffllama import DiffLlamaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut" +_CONFIG_FOR_DOC = "DiffLlamaConfig" + + +class DiffLlamaMLP(MistralMLP): + pass + + +def lambda_init_fn(layer_idx): + return 0.8 - 0.6 * math.exp(-0.3 * layer_idx) + + +class DiffLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DiffLlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + # under this are not used + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + + self.lambda_init = lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) + self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, target_len, _ = hidden_states.size() + q_len = target_len + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = torch.matmul(attn_weights, value_states) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DiffLlamaFlashAttention2(DiffLlamaAttention): + """ + DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DiffLlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + value_states1, value_states2 = torch.chunk(value_states, 2, dim=2) + value_states1 = value_states1.repeat(1, 1, 2, 1) + value_states2 = value_states2.repeat(1, 1, 2, 1) + + attn_output1 = _flash_attention_forward( + query_states, + key_states, + value_states1, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output2 = _flash_attention_forward( + query_states, + key_states, + value_states2, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = torch.cat([attn_output1, attn_output2], dim=-1) + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class DiffLlamaSdpaAttention(DiffLlamaAttention): + """ + DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from DiffLlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "DiffLlamaModel is using DiffLlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1) + value_states = value_states.repeat(1, 2, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1) + + lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to( + query_states.dtype + ) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + attn_output = attn_output1 - lambda_full * attn_output2 + attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +DIFFLLAMA_ATTENTION_CLASSES = { + "eager": DiffLlamaAttention, + "flash_attention_2": DiffLlamaFlashAttention2, + "sdpa": DiffLlamaSdpaAttention, +} + + +class DiffLlamaDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: DiffLlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + + self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + +class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): + _supports_flex_attn = False + _supports_attention_backend = False + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) + elif isinstance(module, DiffLlamaAttention): + module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + + +class DiffLlamaModel(LlamaModel): + pass + + +class DiffLlamaForCausalLM(GemmaForCausalLM): + pass + + +class DiffLlamaForSequenceClassification(LlamaForSequenceClassification): + pass + + +class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +class DiffLlamaForTokenClassification(LlamaForTokenClassification): + pass + + +__all__ = [ + "DiffLlamaPreTrainedModel", + "DiffLlamaModel", # noqa: F822 + "DiffLlamaForCausalLM", + "DiffLlamaForSequenceClassification", + "DiffLlamaForQuestionAnswering", + "DiffLlamaForTokenClassification", +] diff --git a/docs/transformers/build/lib/transformers/models/dinat/__init__.py b/docs/transformers/build/lib/transformers/models/dinat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b64cdbb3c7eb0467f6112225b8c0d9e1f65f9e99 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinat/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinat import * + from .modeling_dinat import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py b/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py new file mode 100644 index 0000000000000000000000000000000000000000..7b432e37c851395e98ff9c0dff859294b3e016f4 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinat/configuration_dinat.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Dilated Neighborhood Attention Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class DinatConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Dinat + [shi-labs/dinat-mini-in1k-224](https://huggingface.co/shi-labs/dinat-mini-in1k-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. NOTE: Only patch size of 4 is supported at the moment. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 64): + Dimensionality of patch embedding. + depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 5]`): + Number of layers in each level of the encoder. + num_heads (`List[int]`, *optional*, defaults to `[2, 4, 8, 16]`): + Number of attention heads in each layer of the Transformer encoder. + kernel_size (`int`, *optional*, defaults to 7): + Neighborhood Attention kernel size. + dilations (`List[List[int]]`, *optional*, defaults to `[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]]`): + Dilation value of each NA layer in the Transformer encoder. + mlp_ratio (`float`, *optional*, defaults to 3.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + layer_scale_init_value (`float`, *optional*, defaults to 0.0): + The initial value for the layer scale. Disabled if <=0. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + + Example: + + ```python + >>> from transformers import DinatConfig, DinatModel + + >>> # Initializing a Dinat shi-labs/dinat-mini-in1k-224 style configuration + >>> configuration = DinatConfig() + + >>> # Initializing a model (with random weights) from the shi-labs/dinat-mini-in1k-224 style configuration + >>> model = DinatModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinat" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + patch_size=4, + num_channels=3, + embed_dim=64, + depths=[3, 4, 6, 5], + num_heads=[2, 4, 8, 16], + kernel_size=7, + dilations=[[1, 8, 1], [1, 4, 1, 4], [1, 2, 1, 2, 1, 2], [1, 1, 1, 1, 1]], + mlp_ratio=3.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + initializer_range=0.02, + layer_norm_eps=1e-5, + layer_scale_init_value=0.0, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.kernel_size = kernel_size + self.dilations = dilations + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Dinat work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.layer_scale_init_value = layer_scale_init_value + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["DinatConfig"] diff --git a/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py b/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py new file mode 100644 index 0000000000000000000000000000000000000000..8837372a84c463af5a92ad23627c9336ec8be24c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinat/modeling_dinat.py @@ -0,0 +1,960 @@ +# coding=utf-8 +# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Dilated Neighborhood Attention Transformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + OptionalDependencyNotAvailable, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_natten_available, + logging, + replace_return_docstrings, + requires_backends, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_dinat import DinatConfig + + +if is_natten_available(): + from natten.functional import natten2dav, natten2dqkrpb +else: + + def natten2dqkrpb(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + def natten2dav(*args, **kwargs): + raise OptionalDependencyNotAvailable() + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "DinatConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "shi-labs/dinat-mini-in1k-224" +_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "shi-labs/dinat-mini-in1k-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +# drop_path and DinatDropPath are from the timm library. + + +@dataclass +class DinatEncoderOutput(ModelOutput): + """ + Dinat encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DinatModelOutput(ModelOutput): + """ + Dinat model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DinatImageClassifierOutput(ModelOutput): + """ + Dinat outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class DinatEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.patch_embeddings = DinatPatchEmbeddings(config) + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]: + embeddings = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class DinatPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + patch_size = config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + self.num_channels = num_channels + + if patch_size == 4: + pass + else: + # TODO: Support arbitrary patch sizes. + raise ValueError("Dinat only supports patch size of 4 at the moment.") + + self.projection = nn.Sequential( + nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), + ) + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor: + _, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + embeddings = embeddings.permute(0, 2, 3, 1) + + return embeddings + + +class DinatDownsampler(nn.Module): + """ + Convolutional Downsampling Layer. + + Args: + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.dim = dim + self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, input_feature: torch.Tensor) -> torch.Tensor: + input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + input_feature = self.norm(input_feature) + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Dinat +class DinatDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class NeighborhoodAttention(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size, dilation): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.kernel_size = kernel_size + self.dilation = dilation + + # rpb is learnable relative positional biases; same concept is used Swin. + self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1))) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 3, 1, 2, 4) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + query_layer = self.transpose_for_scores(self.query(hidden_states)) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + # Apply the scale factor before computing attention weights. It's usually more efficient because + # attention weights are typically a bigger tensor compared to query. + # It gives identical results because scalars are commutable in matrix multiplication. + query_layer = query_layer / math.sqrt(self.attention_head_size) + + # Compute NA between "query" and "key" to get the raw attention scores, and add relative positional biases. + attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, self.dilation) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation) + context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class NeighborhoodAttentionOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class NeighborhoodAttentionModule(nn.Module): + def __init__(self, config, dim, num_heads, kernel_size, dilation): + super().__init__() + self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size, dilation) + self.output = NeighborhoodAttentionOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class DinatIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class DinatOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class DinatLayer(nn.Module): + def __init__(self, config, dim, num_heads, dilation, drop_path_rate=0.0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.kernel_size = config.kernel_size + self.dilation = dilation + self.window_size = self.kernel_size * self.dilation + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = NeighborhoodAttentionModule( + config, dim, num_heads, kernel_size=self.kernel_size, dilation=self.dilation + ) + self.drop_path = DinatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DinatIntermediate(config, dim) + self.output = DinatOutput(config, dim) + self.layer_scale_parameters = ( + nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True) + if config.layer_scale_init_value > 0 + else None + ) + + def maybe_pad(self, hidden_states, height, width): + window_size = self.window_size + pad_values = (0, 0, 0, 0, 0, 0) + if height < window_size or width < window_size: + pad_l = pad_t = 0 + pad_r = max(0, window_size - width) + pad_b = max(0, window_size - height) + pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, height, width, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + # pad hidden_states if they are smaller than kernel size x dilation + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + + attention_outputs = self.attention(hidden_states, output_attentions=output_attentions) + + attention_output = attention_outputs[0] + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_output = attention_output[:, :height, :width, :].contiguous() + + if self.layer_scale_parameters is not None: + attention_output = self.layer_scale_parameters[0] * attention_output + + hidden_states = shortcut + self.drop_path(attention_output) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.output(self.intermediate(layer_output)) + + if self.layer_scale_parameters is not None: + layer_output = self.layer_scale_parameters[1] * layer_output + + layer_output = hidden_states + self.drop_path(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +class DinatStage(nn.Module): + def __init__(self, config, dim, depth, num_heads, dilations, drop_path_rate, downsample): + super().__init__() + self.config = config + self.dim = dim + self.layers = nn.ModuleList( + [ + DinatLayer( + config=config, + dim=dim, + num_heads=num_heads, + dilation=dilations[i], + drop_path_rate=drop_path_rate[i], + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + _, height, width, _ = hidden_states.size() + for i, layer_module in enumerate(self.layers): + layer_outputs = layer_module(hidden_states, output_attentions) + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + hidden_states = self.downsample(hidden_states_before_downsampling) + + stage_outputs = (hidden_states, hidden_states_before_downsampling) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class DinatEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.num_levels = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.levels = nn.ModuleList( + [ + DinatStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + dilations=config.dilations[i_layer], + drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DinatDownsampler if (i_layer < self.num_levels - 1) else None, + ) + for i_layer in range(self.num_levels) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, DinatEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.levels): + layer_outputs = layer_module(hidden_states, output_attentions) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + + if output_hidden_states and output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + # rearrange b h w c -> b c h w + reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DinatEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +class DinatPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DinatConfig + base_model_prefix = "dinat" + main_input_name = "pixel_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +DINAT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`DinatConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINAT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Dinat Model transformer outputting raw hidden-states without any specific head on top.", + DINAT_START_DOCSTRING, +) +class DinatModel(DinatPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.config = config + self.num_levels = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1)) + + self.embeddings = DinatEmbeddings(config) + self.encoder = DinatEncoder(config) + + self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps) + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=DinatModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DinatModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return DinatModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + Dinat Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINAT_START_DOCSTRING, +) +class DinatForImageClassification(DinatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + requires_backends(self, ["natten"]) + + self.num_labels = config.num_labels + self.dinat = DinatModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.dinat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=DinatImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DinatImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinat( + pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DinatImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + "NAT backbone, to be used with frameworks like DETR and MaskFormer.", + DINAT_START_DOCSTRING, +) +class DinatBackbone(DinatPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + requires_backends(self, ["natten"]) + + self.embeddings = DinatEmbeddings(config) + self.encoder = DinatEncoder(config) + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] + + # Add layer norms to hidden states of out_features + hidden_states_norms = {} + for stage, num_channels in zip(self._out_features, self.channels): + hidden_states_norms[stage] = nn.LayerNorm(num_channels) + self.hidden_states_norms = nn.ModuleDict(hidden_states_norms) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224") + >>> model = AutoBackbone.from_pretrained( + ... "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 512, 7, 7] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_attentions=output_attentions, + output_hidden_states=True, + output_hidden_states_before_downsampling=True, + return_dict=True, + ) + + hidden_states = outputs.reshaped_hidden_states + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + batch_size, num_channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state = hidden_state.view(batch_size, height * width, num_channels) + hidden_state = self.hidden_states_norms[stage](hidden_state) + hidden_state = hidden_state.view(batch_size, height, width, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output += (outputs.hidden_states,) + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["DinatForImageClassification", "DinatModel", "DinatPreTrainedModel", "DinatBackbone"] diff --git a/docs/transformers/build/lib/transformers/models/dinov2/__init__.py b/docs/transformers/build/lib/transformers/models/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3cc316957eac509573bf44785209d0729ea13bb6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinov2 import * + from .modeling_dinov2 import * + from .modeling_flax_dinov2 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b29273a509d3e25d7a11eb77de8808e0893f96 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2/configuration_dinov2.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DINOv2 model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class Dinov2Config(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2Model`]. It is used to instantiate an + Dinov2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Dinov2 + [google/dinov2-base-patch16-224](https://huggingface.co/google/dinov2-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + use_mask_token (`bool`, *optional*, defaults to `True`): + Whether to use mask_token in embeddings. + + Example: + + ```python + >>> from transformers import Dinov2Config, Dinov2Model + + >>> # Initializing a Dinov2 dinov2-base-patch16-224 style configuration + >>> configuration = Dinov2Config() + + >>> # Initializing a model (with random weights) from the dinov2-base-patch16-224 style configuration + >>> model = Dinov2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=14, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + use_mask_token=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + self.use_mask_token = use_mask_token + + +class Dinov2OnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["Dinov2Config", "Dinov2OnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py b/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..d716191b2fcbd4775bd2349ef98a7ad0d781a90c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2/convert_dinov2_to_hf.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DINOv2 checkpoints from the original repository. + +URL: https://github.com/facebookresearch/dinov2/tree/main +""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dinov2_config(model_name, image_classifier=False): + config = Dinov2Config(image_size=518, patch_size=14) + + # size of the architecture + if "vits" in model_name: + config.hidden_size = 384 + config.num_attention_heads = 6 + elif "vitb" in model_name: + pass + elif "vitl" in model_name: + config.hidden_size = 1024 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif "vitg" in model_name: + config.use_swiglu_ffn = True + config.hidden_size = 1536 + config.num_hidden_layers = 40 + config.num_attention_heads = 24 + else: + raise ValueError("Model not supported") + + if image_classifier: + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + config.num_labels = 1000 + config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + config.id2label = {int(k): v for k, v in config.id2label.items()} + + return config + + +def create_rename_keys(config): + rename_keys = [] + # fmt: off + + # patch embedding layer + rename_keys.append(("cls_token", "embeddings.cls_token")) + rename_keys.append(("mask_token", "embeddings.mask_token")) + rename_keys.append(("pos_embed", "embeddings.position_embeddings")) + rename_keys.append(("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias")) + + for i in range(config.num_hidden_layers): + # layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"encoder.layer.{i}.norm1.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"encoder.layer.{i}.norm1.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"encoder.layer.{i}.norm2.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"encoder.layer.{i}.norm2.bias")) + # MLP + if config.use_swiglu_ffn: + rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"encoder.layer.{i}.mlp.w12.weight")) + rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"encoder.layer.{i}.mlp.w12.bias")) + rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"encoder.layer.{i}.mlp.w3.weight")) + rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"encoder.layer.{i}.mlp.w3.bias")) + else: + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"encoder.layer.{i}.mlp.fc1.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"encoder.layer.{i}.mlp.fc1.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"encoder.layer.{i}.mlp.fc2.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"encoder.layer.{i}.mlp.fc2.bias")) + # layerscale + rename_keys.append((f"blocks.{i}.ls1.gamma", f"encoder.layer.{i}.layer_scale1.lambda1")) + rename_keys.append((f"blocks.{i}.ls2.gamma", f"encoder.layer.{i}.layer_scale2.lambda1")) + # attention projection layer + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"encoder.layer.{i}.attention.output.dense.bias")) + + # final layernorm + rename_keys.append(("norm.weight", "layernorm.weight")) + rename_keys.append(("norm.bias", "layernorm.bias")) + + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :] + state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return image + + +@torch.no_grad() +def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our DINOv2 structure. + """ + + # define default Dinov2 configuration + image_classifier = "1layer" in model_name + config = get_dinov2_config(model_name, image_classifier=image_classifier) + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", "")) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config) + + for key, val in state_dict.copy().items(): + val = state_dict.pop(key) + if "w12" in key: + key = key.replace("w12", "weights_in") + if "w3" in key: + key = key.replace("w3", "weights_out") + state_dict[key] = val + + # load HuggingFace model + if image_classifier: + model = Dinov2ForImageClassification(config).eval() + model.dinov2.load_state_dict(state_dict) + model_name_to_classifier_dict_url = { + "dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth", + "dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth", + "dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth", + "dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth", + } + url = model_name_to_classifier_dict_url[model_name] + classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.classifier.weight = nn.Parameter(classifier_state_dict["weight"]) + model.classifier.bias = nn.Parameter(classifier_state_dict["bias"]) + else: + model = Dinov2Model(config).eval() + model.load_state_dict(state_dict) + + # load image + image = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=IMAGENET_DEFAULT_MEAN, # these are RGB mean+std values + std=IMAGENET_DEFAULT_STD, # across a large photo dataset. + ), + ] + ) + + original_pixel_values = transformations(image).unsqueeze(0) # insert batch dimension + + processor = BitImageProcessor( + size={"shortest_edge": 256}, + resample=PILImageResampling.BICUBIC, + image_mean=IMAGENET_DEFAULT_MEAN, + image_std=IMAGENET_DEFAULT_STD, + ) + pixel_values = processor(image, return_tensors="pt").pixel_values + + assert torch.allclose(original_pixel_values, pixel_values) + + with torch.no_grad(): + outputs = model(pixel_values, output_hidden_states=True) + original_outputs = original_model(pixel_values) + + # assert values + if image_classifier: + print("Predicted class:") + class_idx = outputs.logits.argmax(-1).item() + print(model.config.id2label[class_idx]) + else: + assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape + assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model_name_to_hf_name = { + "dinov2_vits14": "dinov2-small", + "dinov2_vitb14": "dinov2-base", + "dinov2_vitl14": "dinov2-large", + "dinov2_vitg14": "dinov2-giant", + "dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer", + "dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer", + "dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer", + "dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer", + } + + name = model_name_to_hf_name[model_name] + model.push_to_hub(f"facebook/{name}") + processor.push_to_hub(f"facebook/{name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dinov2_vitb14", + type=str, + choices=[ + "dinov2_vits14", + "dinov2_vitb14", + "dinov2_vitl14", + "dinov2_vitg14", + "dinov2_vits14_1layer", + "dinov2_vitb14_1layer", + "dinov2_vitl14_1layer", + "dinov2_vitg14_1layer", + ], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_dinov2_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed5a4ec6cb7b180e23c84a8beb1b3a2dc0ed392 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2/modeling_dinov2.py @@ -0,0 +1,904 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DINOv2 model.""" + +import collections.abc +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BackboneOutput, + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_dinov2 import Dinov2Config + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "Dinov2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2-base" +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +class Dinov2Embeddings(nn.Module): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + if config.use_mask_token: + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.patch_embeddings = Dinov2PatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.use_mask_token = config.use_mask_token + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + target_dtype = patch_pos_embed.dtype + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(torch.float32), + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ).to(dtype=target_dtype) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None and self.use_mask_token: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2PatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2 +class Dinov2SelfAttention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2 +class Dinov2SelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2 +class Dinov2Attention(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.attention = Dinov2SelfAttention(config) + self.output = Dinov2SelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class Dinov2LayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class Dinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2MLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2SwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2Layer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Dinov2Attention(config) + self.layer_scale1 = Dinov2LayerScale(config) + self.drop_path = Dinov2DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2SwiGLUFFN(config) + else: + self.mlp = Dinov2MLP(config) + self.layer_scale2 = Dinov2LayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2 +class Dinov2Encoder(nn.Module): + def __init__(self, config: Dinov2Config) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Dinov2Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2SwiGLUFFN"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2Embeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + if self.config.use_mask_token: + module.mask_token.data.zero_() + elif isinstance(module, Dinov2LayerScale): + module.lambda1.data.fill_(self.config.layerscale_value) + + +DINOV2_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class Dinov2Model(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2ForImageClassification(Dinov2PreTrainedModel): + def __init__(self, config: Dinov2Config) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2 = Dinov2Model(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_START_DOCSTRING, +) +class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2Embeddings(config) + self.encoder = Dinov2Encoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2PatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, 1:] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = ["Dinov2ForImageClassification", "Dinov2Model", "Dinov2PreTrainedModel", "Dinov2Backbone"] diff --git a/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py b/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..2766850e921ea805d525c330858babab23614204 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2/modeling_flax_dinov2.py @@ -0,0 +1,801 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax DINOv2 model.""" + +import collections.abc +import math +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from .configuration_dinov2 import Dinov2Config + + +DINOV2_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`Dinov2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +DINOV2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Dinov2ImageProcessor.__call__`] + for details. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxDinov2PatchEmbeddings(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + image_size = self.config.image_size + patch_size = self.config.patch_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.num_patches = num_patches + self.num_channels = self.config.num_channels + self.projection = nn.Conv( + self.config.hidden_size, + kernel_size=patch_size, + strides=patch_size, + padding="VALID", + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + ) + + # Copied from transformers.models.vit.modeling_flax_vit.FlaxViTPatchEmbeddings.__call__ + def __call__(self, pixel_values): + num_channels = pixel_values.shape[-1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values) + batch_size, _, _, channels = embeddings.shape + return jnp.reshape(embeddings, (batch_size, -1, channels)) + + +class FlaxDinov2Embeddings(nn.Module): + """Construct the CLS token, position and patch embeddings.""" + + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.cls_token = self.param( + "cls_token", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, 1, self.config.hidden_size), + ) + if self.config.use_mask_token: + self.mask_token = self.param( + "mask_token", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, self.config.hidden_size), + ) + self.patch_embeddings = FlaxDinov2PatchEmbeddings(self.config, dtype=self.dtype) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = self.param( + "position_embeddings", + jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"), + (1, num_patches + 1, self.config.hidden_size), + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, config, hidden_states, height, width, position_embeddings): + num_patches = hidden_states.shape[1] - 1 + num_positions = position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return position_embeddings + class_pos_embed = position_embeddings[:, 0] + patch_pos_embed = position_embeddings[:, 1:] + dim = hidden_states.shape[-1] + + h = height // config.patch_size + w = width // config.patch_size + height, width = h + 0.1, w + 0.1 + + patch_pos_embed = patch_pos_embed.reshape( + (1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ) + patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 3, 1, 2)) + target_dtype = patch_pos_embed.dtype + new_height_ratio = jnp.float32(height / math.sqrt(num_positions)) + new_width_ratio = jnp.float32(width / math.sqrt(num_positions)) + + scale = jnp.array([new_height_ratio, new_width_ratio], dtype=jnp.float32) + translation = jnp.array([0.0, 0.0], dtype=jnp.float32) + + patch_pos_embed = jax.image.scale_and_translate( + patch_pos_embed.astype(jnp.float32), + shape=(patch_pos_embed.shape[0], patch_pos_embed.shape[1], h, w), + spatial_dims=(2, 3), + scale=scale, + translation=translation, + method="bicubic", + antialias=False, + ) + patch_pos_embed = patch_pos_embed.astype(target_dtype) + patch_pos_embed = jnp.transpose(patch_pos_embed, (0, 2, 3, 1)).reshape((position_embeddings.shape[0], -1, dim)) + patch_pos_embed_expanded = jnp.tile(patch_pos_embed, (hidden_states.shape[0], 1, 1)) + class_pos_embed_expanded = jnp.tile(class_pos_embed, (hidden_states.shape[0], 1, 1)) + + return jnp.concatenate((class_pos_embed_expanded, patch_pos_embed_expanded), axis=1) + + def __call__(self, pixel_values, deterministic=True): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.projection.dtype + height, width = pixel_values.shape[1], pixel_values.shape[2] + + embeddings = self.patch_embeddings(pixel_values.astype(target_dtype)) + + cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size)) + embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) + + embeddings = embeddings + self.interpolate_pos_encoding( + self.config, embeddings, height, width, self.position_embeddings + ) + + embeddings = self.dropout(embeddings, deterministic=deterministic) + return embeddings + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfAttention with ViT->Dinov2 +class FlaxDinov2SelfAttention(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:" + " {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal" + ), + use_bias=self.config.qkv_bias, + ) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + head_dim = self.config.hidden_size // self.config.num_attention_heads + + query_states = self.query(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + value_states = self.value(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + key_states = self.key(hidden_states).reshape( + hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim) + ) + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTSelfOutput with ViT->Dinov2 +class FlaxDinov2SelfOutput(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTAttention with ViT->Dinov2 +class FlaxDinov2Attention(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxDinov2SelfAttention(self.config, dtype=self.dtype) + self.output = FlaxDinov2SelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): + attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +def ones_with_scale(key, shape, scale, dtype=jnp.float32): + return jnp.ones(shape, dtype) * scale + + +class FlaxDinov2LayerScale(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.lambda1 = self.config.layerscale_value * self.param( + "lambda1", + jax.nn.initializers.ones, + (self.config.hidden_size,), + ) + self.lambda1 = self.lambda1 * self.config.layerscale_value + + def __call__(self, hidden_states): + return self.lambda1 * hidden_states + + +# Copied from transformers.models.beit.modeling_flax_beit.FlaxBeitDropPath with Beit -> Dinov2 +class FlaxDinov2DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + rate: float + + @nn.module.compact + def __call__(self, inputs, deterministic: Optional[bool] = True): + if self.rate == 0.0: + return inputs + keep_prob = 1.0 - self.rate + if deterministic: + return inputs + else: + shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + rng = self.make_rng("droppath") + random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype) + binary_tensor = jnp.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + +class FlaxDinov2MLP(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.fc1 = nn.Dense( + self.config.hidden_size * self.config.mlp_ratio, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.fc2 = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + if isinstance(self.config.hidden_act, str): + self.act = ACT2FN[self.config.hidden_act] + else: + self.act = self.config.hidden_act + + def __call__(self, hidden_states): + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class FlaxDinov2SwiGLUFFN(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + hidden_features = int(self.config.hidden_size * self.config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Dense( + 2 * hidden_features, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.weights_out = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + hidden_states = self.weights_in(hidden_states) + x1, x2 = jnp.split(hidden_states, 2, axis=-1) + hidden = nn.silu(x1) * x2 + return self.weights_out(hidden) + + +class FlaxDinov2Layer(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.attention = FlaxDinov2Attention(self.config, dtype=self.dtype) + self.layer_scale1 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) + self.drop_path = FlaxDinov2DropPath(self.config.drop_path_rate) + self.norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + if self.config.use_swiglu_ffn: + self.mlp = FlaxDinov2SwiGLUFFN(self.config, dtype=self.dtype) + else: + self.mlp = FlaxDinov2MLP(self.config, dtype=self.dtype) + + self.layer_scale2 = FlaxDinov2LayerScale(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2, layernorm is applied before self-attention + deterministic=deterministic, + output_attentions=output_attentions, + ) + + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + + outputs = self_attention_outputs[1:] + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTLayerCollection with ViT->Dinov2 +class FlaxDinov2LayerCollection(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxDinov2Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.vit.modeling_flax_vit.FlaxViTEncoder with ViT->Dinov2 +class FlaxDinov2Encoder(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxDinov2LayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxDinov2PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2Config + base_model_prefix = "dinov2" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: Dinov2Config, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + dropout_rng, droppath_rng = jax.random.split(dropout_rng) + rngs["dropout"] = dropout_rng + rngs["droppath"] = droppath_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxDinov2Module(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embeddings = FlaxDinov2Embeddings(self.config, dtype=self.dtype) + self.encoder = FlaxDinov2Encoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + encoder_outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + "The bare Dinov2 Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_START_DOCSTRING, +) +class FlaxDinov2Model(FlaxDinov2PreTrainedModel): + module_class = FlaxDinov2Module + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxDinov2Model + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base") + >>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxDinov2Model, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings( + FlaxDinov2Model, output_type=FlaxBaseModelOutputWithPooling, config_class=Dinov2Config +) + + +class FlaxDinov2ForImageClassificationModule(nn.Module): + config: Dinov2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dinov2 = FlaxDinov2Module(config=self.config, dtype=self.dtype) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + ) + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2( + pixel_values, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + cls_token = hidden_states[:, 0] + patch_tokens = hidden_states[:, 1:] + linear_input = jnp.concatenate([cls_token, patch_tokens.mean(axis=1)], axis=-1) + + logits = self.classifier(linear_input) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + """, + DINOV2_START_DOCSTRING, +) +class FlaxDinov2ForImageClassification(FlaxDinov2PreTrainedModel): + module_class = FlaxDinov2ForImageClassificationModule + + +FLAX_VISION_CLASSIFICATION_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxDinov2ForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer") + >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True) + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxDinov2ForImageClassification, FLAX_VISION_CLASSIFICATION_DOCSTRING) +append_replace_return_docstrings( + FlaxDinov2ForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=Dinov2Config +) + + +__all__ = ["FlaxDinov2ForImageClassification", "FlaxDinov2Model", "FlaxDinov2PreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d10027b6a3b6375235a6785df044e8f0ce5fb33 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dinov2_with_registers import * + from .modeling_dinov2_with_registers import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py new file mode 100644 index 0000000000000000000000000000000000000000..a0295899fcd72f6481aa4e12cbef8630652f8149 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py @@ -0,0 +1,159 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dinov2_with_registers.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an + Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 with Registers + [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of register tokens to use. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + + Example: + + ```python + >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + + >>> # Initializing a Dinov2WithRegisters base style configuration + >>> configuration = Dinov2WithRegistersConfig() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = Dinov2WithRegistersModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2_with_registers" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + num_register_tokens=4, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.num_register_tokens = num_register_tokens + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + +__all__ = ["Dinov2WithRegistersConfig"] diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff2697f74667e1b941ec65adb1a39cfd0a87460 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/convert_dinov2_with_registers_to_hf.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DINOv2 with Registers checkpoints from the original repository. + +URL: https://github.com/facebookresearch/dinov2/tree/main +""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms + +from transformers import ( + BitImageProcessor, + Dinov2WithRegistersConfig, + Dinov2WithRegistersForImageClassification, + Dinov2WithRegistersModel, +) +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dinov2_with_registers_config(model_name, image_classifier=False): + config = Dinov2WithRegistersConfig(image_size=518, patch_size=14) + + # size of the architecture + if "vits" in model_name: + config.hidden_size = 384 + config.num_attention_heads = 6 + elif "vitb" in model_name: + pass + elif "vitl" in model_name: + config.hidden_size = 1024 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + elif "vitg" in model_name: + config.use_swiglu_ffn = True + config.hidden_size = 1536 + config.num_hidden_layers = 40 + config.num_attention_heads = 24 + else: + raise ValueError("Model not supported") + + if image_classifier: + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + config.num_labels = 1000 + config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + config.id2label = {int(k): v for k, v in config.id2label.items()} + + return config + + +def create_rename_keys(config): + rename_keys = [] + # fmt: off + + # patch embedding layer + rename_keys.append(("cls_token", "embeddings.cls_token")) + rename_keys.append(("mask_token", "embeddings.mask_token")) + rename_keys.append(("pos_embed", "embeddings.position_embeddings")) + rename_keys.append(("register_tokens", "embeddings.register_tokens")) + rename_keys.append(("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias")) + + for i in range(config.num_hidden_layers): + # layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"encoder.layer.{i}.norm1.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"encoder.layer.{i}.norm1.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"encoder.layer.{i}.norm2.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"encoder.layer.{i}.norm2.bias")) + # MLP + if config.use_swiglu_ffn: + rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"encoder.layer.{i}.mlp.w12.weight")) + rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"encoder.layer.{i}.mlp.w12.bias")) + rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"encoder.layer.{i}.mlp.w3.weight")) + rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"encoder.layer.{i}.mlp.w3.bias")) + else: + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"encoder.layer.{i}.mlp.fc1.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"encoder.layer.{i}.mlp.fc1.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"encoder.layer.{i}.mlp.fc2.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"encoder.layer.{i}.mlp.fc2.bias")) + # layerscale + rename_keys.append((f"blocks.{i}.ls1.gamma", f"encoder.layer.{i}.layer_scale1.lambda1")) + rename_keys.append((f"blocks.{i}.ls2.gamma", f"encoder.layer.{i}.layer_scale2.lambda1")) + # attention projection layer + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"encoder.layer.{i}.attention.output.dense.bias")) + + # final layernorm + rename_keys.append(("norm.weight", "layernorm.weight")) + rename_keys.append(("norm.bias", "layernorm.bias")) + + # fmt: on + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :] + state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + return image + + +@torch.no_grad() +def convert_dinov2_with_registers_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our Dinov2WithRegisters structure. + """ + + # define default Dinov2WithRegisters configuration + image_classifier = "1layer" in model_name + config = get_dinov2_with_registers_config(model_name, image_classifier=image_classifier) + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", "")) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config) + + for key, val in state_dict.copy().items(): + val = state_dict.pop(key) + if "w12" in key: + key = key.replace("w12", "weights_in") + if "w3" in key: + key = key.replace("w3", "weights_out") + state_dict[key] = val + + # load HuggingFace model + if image_classifier: + model = Dinov2WithRegistersForImageClassification(config).eval() + model.dinov2_with_registers.load_state_dict(state_dict) + model_name_to_classifier_dict_url = { + "dinov2_vits14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_linear_head.pth", + "dinov2_vitb14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_linear_head.pth", + "dinov2_vitl14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_linear_head.pth", + "dinov2_vitg14_reg_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_linear_head.pth", + } + url = model_name_to_classifier_dict_url[model_name] + classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.classifier.weight = nn.Parameter(classifier_state_dict["weight"]) + model.classifier.bias = nn.Parameter(classifier_state_dict["bias"]) + else: + model = Dinov2WithRegistersModel(config).eval() + model.load_state_dict(state_dict) + + # load image + image = prepare_img() + + # preprocess image + transformations = transforms.Compose( + [ + transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=IMAGENET_DEFAULT_MEAN, # these are RGB mean+std values + std=IMAGENET_DEFAULT_STD, # across a large photo dataset. + ), + ] + ) + + original_pixel_values = transformations(image).unsqueeze(0) # insert batch dimension + + processor = BitImageProcessor( + size={"shortest_edge": 256}, + resample=PILImageResampling.BICUBIC, + image_mean=IMAGENET_DEFAULT_MEAN, + image_std=IMAGENET_DEFAULT_STD, + ) + pixel_values = processor(image, return_tensors="pt").pixel_values + + assert torch.allclose(original_pixel_values, pixel_values) + + with torch.no_grad(): + outputs = model(pixel_values, output_hidden_states=True) + original_outputs = original_model(pixel_values) + + # assert values + if image_classifier: + print("Predicted class:") + class_idx = outputs.logits.argmax(-1).item() + print(model.config.id2label[class_idx]) + else: + assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape + assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model_name_to_hf_name = { + "dinov2_vits14_reg": "dinov2-with-registers-small", + "dinov2_vitb14_reg": "dinov2-with-registers-base", + "dinov2_vitl14_reg": "dinov2-with-registers-large", + "dinov2_vitg14_reg": "dinov2-with-registers-giant", + "dinov2_vits14_reg_1layer": "dinov2-with-registers-small-imagenet1k-1-layer", + "dinov2_vitb14_reg_1layer": "dinov2-with-registers-base-imagenet1k-1-layer", + "dinov2_vitl14_reg_1layer": "dinov2-with-registers-large-imagenet1k-1-layer", + "dinov2_vitg14_reg_1layer": "dinov2-with-registers-giant-imagenet1k-1-layer", + } + + name = model_name_to_hf_name[model_name] + model.push_to_hub(f"nielsr/{name}") + processor.push_to_hub(f"nielsr/{name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dinov2_vits14_reg", + type=str, + choices=[ + "dinov2_vits14_reg", + "dinov2_vitb14_reg", + "dinov2_vitl14_reg", + "dinov2_vitg14_reg", + "dinov2_vits14_reg_1layer", + "dinov2_vitb14_reg_1layer", + "dinov2_vitl14_reg_1layer", + "dinov2_vitg14_reg_1layer", + ], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_dinov2_with_registers_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py new file mode 100644 index 0000000000000000000000000000000000000000..449bfb9b91cdcaf665c6a3e0ac4cad828d5779eb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -0,0 +1,930 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_dinov2_with_registers.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_dinov2_with_registers import Dinov2WithRegistersConfig + + +logger = logging.get_logger(__name__) + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base" + +# General docstring +_CONFIG_FOR_DOC = "Dinov2WithRegistersConfig" + + +class Dinov2WithRegistersPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class Dinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility + with the original implementation. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py + - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # Skip interpolation for matching dimensions (unless tracing) + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + # Handle class token and patch embeddings separately + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + + # Calculate new dimensions + height = height // self.config.patch_size + width = width // self.config.patch_size + + # Reshape for interpolation + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Store original dtype for restoration after interpolation + target_dtype = patch_pos_embed.dtype + + # Interpolate at float32 precision + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + # Validate output dimensions if not tracing + if not torch.jit.is_tracing(): + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + # Reshape back to original format + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + # Combine class and patch embeddings + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + # add register tokens + embeddings = torch.cat( + (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Dinov2WithRegistersSelfAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class Dinov2WithRegistersSelfOutput(nn.Module): + """ + The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class Dinov2WithRegistersAttention(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.attention = Dinov2WithRegistersSelfAttention(config) + self.output = Dinov2WithRegistersSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class Dinov2WithRegistersLayerScale(nn.Module): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size)) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class Dinov2WithRegistersDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class Dinov2WithRegistersMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class Dinov2WithRegistersSwiGLUFFN(nn.Module): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class Dinov2WithRegistersLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = Dinov2WithRegistersAttention(config) + self.layer_scale1 = Dinov2WithRegistersLayerScale(config) + self.drop_path = ( + Dinov2WithRegistersDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + ) + + self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = Dinov2WithRegistersSwiGLUFFN(config) + else: + self.mlp = Dinov2WithRegistersMLP(config) + self.layer_scale2 = Dinov2WithRegistersLayerScale(config) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Dinov2WithRegisters, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Dinov2WithRegisters, layernorm is also applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class Dinov2WithRegistersEncoder(nn.Module): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Dinov2WithRegistersConfig + base_model_prefix = "dinov2_with_registers" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + module.mask_token.data.zero_() + module.register_tokens.data.zero_() + elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 + module.lambda1.data.fill_(self.config.layerscale_value) + + +_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] + + +DINOV2_WITH_REGISTERS_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for + pre-training. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.", + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig): + super().__init__(config) + self.config = config + + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + head_outputs = (sequence_output, pooled_output) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + +DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`BitImageProcessor.preprocess`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + """ + Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state + of the [CLS] token) e.g. for ImageNet. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel): + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.dinov2_with_registers = Dinov2WithRegistersModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(config.hidden_size * 2, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.dinov2_with_registers( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] # batch_size, sequence_length, hidden_size + + cls_token = sequence_output[:, 0] + patch_tokens = sequence_output[:, 1:] + + linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1) + + logits = self.classifier(linear_input) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer. + """, + DINOV2_WITH_REGISTERS_START_DOCSTRING, +) +class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.num_register_tokens = config.num_register_tokens + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + Returns: + + Examples: + + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "Dinov2WithRegistersPreTrainedModel", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersBackbone", +] diff --git a/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py new file mode 100644 index 0000000000000000000000000000000000000000..59777e2157894e7641c8512c7ff50febb75aed43 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -0,0 +1,420 @@ +# coding=utf-8 +# Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ....transformers.models.dinov2.modeling_dinov2 import ( + Dinov2Backbone, + Dinov2Encoder, + Dinov2ForImageClassification, + Dinov2Model, + Dinov2PatchEmbeddings, + Dinov2PreTrainedModel, +) +from ...configuration_utils import PretrainedConfig +from ...modeling_outputs import BackboneOutput +from ...utils import logging, torch_int +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class Dinov2WithRegistersConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Dinov2WithRegistersModel`]. It is used to instantiate an + Dinov2WithRegisters model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the DINOv2 with Registers + [facebook/dinov2-with-registers-base](https://huggingface.co/facebook/dinov2-with-registers-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of the hidden size of the MLPs relative to the `hidden_size`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + layerscale_value (`float`, *optional*, defaults to 1.0): + Initial value to use for layer scale. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate per sample (when applied in the main path of residual layers). + use_swiglu_ffn (`bool`, *optional*, defaults to `False`): + Whether to use the SwiGLU feedforward neural network. + num_register_tokens (`int`, *optional*, defaults to 4): + Number of register tokens to use. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + apply_layernorm (`bool`, *optional*, defaults to `True`): + Whether to apply layer normalization to the feature maps in case the model is used as backbone. + reshape_hidden_states (`bool`, *optional*, defaults to `True`): + Whether to reshape the feature maps to 4D tensors of shape `(batch_size, hidden_size, height, width)` in + case the model is used as backbone. If `False`, the feature maps will be 3D tensors of shape `(batch_size, + seq_len, hidden_size)`. + + Example: + + ```python + >>> from transformers import Dinov2WithRegistersConfig, Dinov2WithRegistersModel + + >>> # Initializing a Dinov2WithRegisters base style configuration + >>> configuration = Dinov2WithRegistersConfig() + + >>> # Initializing a model (with random weights) from the base style configuration + >>> model = Dinov2WithRegistersModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dinov2_with_registers" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + layerscale_value=1.0, + drop_path_rate=0.0, + use_swiglu_ffn=False, + num_register_tokens=4, + out_features=None, + out_indices=None, + apply_layernorm=True, + reshape_hidden_states=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.layerscale_value = layerscale_value + self.drop_path_rate = drop_path_rate + self.use_swiglu_ffn = use_swiglu_ffn + self.num_register_tokens = num_register_tokens + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self.apply_layernorm = apply_layernorm + self.reshape_hidden_states = reshape_hidden_states + + +class Dinov2WithRegistersPatchEmbeddings(Dinov2PatchEmbeddings): + pass + + +class Dinov2WithRegistersEmbeddings(nn.Module): + """ + Construct the CLS token, mask token, register tokens, position and patch embeddings. + """ + + def __init__(self, config: Dinov2WithRegistersConfig) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size)) + self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility + with the original implementation. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py + - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py + """ + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # Skip interpolation for matching dimensions (unless tracing) + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + # Handle class token and patch embeddings separately + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + + # Calculate new dimensions + height = height // self.config.patch_size + width = width // self.config.patch_size + + # Reshape for interpolation + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + # Store original dtype for restoration after interpolation + target_dtype = patch_pos_embed.dtype + + # Interpolate at float32 precision + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.to(dtype=torch.float32), + size=(torch_int(height), torch_int(width)), # Explicit size instead of scale_factor + mode="bicubic", + align_corners=False, + antialias=True, + ).to(dtype=target_dtype) + + # Validate output dimensions if not tracing + if not torch.jit.is_tracing(): + if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: + raise ValueError("Width or height does not match with the interpolated position embeddings") + + # Reshape back to original format + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + # Combine class and patch embeddings + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + batch_size, _, height, width = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + if bool_masked_pos is not None: + embeddings = torch.where( + bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + + # add register tokens + embeddings = torch.cat( + (embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1 + ) + + embeddings = self.dropout(embeddings) + + return embeddings + + +class Dinov2WithRegistersEncoder(Dinov2Encoder): + pass + + +class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel): + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Dinov2WithRegistersEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + module.mask_token.data.zero_() + module.register_tokens.data.zero_() + elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 + module.lambda1.data.fill_(self.config.layerscale_value) + + +class Dinov2WithRegistersModel(Dinov2Model): + pass + + +class Dinov2WithRegistersForImageClassification(Dinov2ForImageClassification): + pass + + +class Dinov2WithRegistersBackbone(Dinov2Backbone): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.num_register_tokens = config.num_register_tokens + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = Dinov2WithRegistersEmbeddings(config) + self.encoder = Dinov2WithRegistersEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings: + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, AutoBackbone + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base") + >>> model = AutoBackbone.from_pretrained( + ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"] + ... ) + + >>> inputs = processor(image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 16, 16] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + if self.config.apply_layernorm: + hidden_state = self.layernorm(hidden_state) + if self.config.reshape_hidden_states: + hidden_state = hidden_state[:, self.num_register_tokens + 1 :] + # this was actually a bug in the original implementation that we copied here, + # cause normally the order is height, width + batch_size, _, height, width = pixel_values.shape + patch_size = self.config.patch_size + hidden_state = hidden_state.reshape(batch_size, height // patch_size, width // patch_size, -1) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions if output_attentions else None, + ) + + +__all__ = [ + "Dinov2WithRegistersConfig", + "Dinov2WithRegistersPreTrainedModel", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersBackbone", +] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/__init__.py b/docs/transformers/build/lib/transformers/models/distilbert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6fae2e0236e7619988f0cfa3502ed49d0f90b0 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_distilbert import * + from .modeling_distilbert import * + from .modeling_flax_distilbert import * + from .modeling_tf_distilbert import * + from .tokenization_distilbert import * + from .tokenization_distilbert_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a28c8e5d03d029d344c3f9f3294c9298a9fd808 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/configuration_distilbert.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DistilBERT model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DistilBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DistilBertModel`] or a [`TFDistilBertModel`]. It + is used to instantiate a DistilBERT model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the DistilBERT + [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the DistilBERT model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DistilBertModel`] or [`TFDistilBertModel`]. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + sinusoidal_pos_embds (`boolean`, *optional*, defaults to `False`): + Whether to use sinusoidal positional embeddings. + n_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + n_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + dim (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + hidden_dim (`int`, *optional*, defaults to 3072): + The size of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + qa_dropout (`float`, *optional*, defaults to 0.1): + The dropout probabilities used in the question answering model [`DistilBertForQuestionAnswering`]. + seq_classif_dropout (`float`, *optional*, defaults to 0.2): + The dropout probabilities used in the sequence classification and the multiple choice model + [`DistilBertForSequenceClassification`]. + + Examples: + + ```python + >>> from transformers import DistilBertConfig, DistilBertModel + + >>> # Initializing a DistilBERT configuration + >>> configuration = DistilBertConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DistilBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "distilbert" + attribute_map = { + "hidden_size": "dim", + "num_attention_heads": "n_heads", + "num_hidden_layers": "n_layers", + } + + def __init__( + self, + vocab_size=30522, + max_position_embeddings=512, + sinusoidal_pos_embds=False, + n_layers=6, + n_heads=12, + dim=768, + hidden_dim=4 * 768, + dropout=0.1, + attention_dropout=0.1, + activation="gelu", + initializer_range=0.02, + qa_dropout=0.1, + seq_classif_dropout=0.2, + pad_token_id=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.n_layers = n_layers + self.n_heads = n_heads + self.dim = dim + self.hidden_dim = hidden_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation = activation + self.initializer_range = initializer_range + self.qa_dropout = qa_dropout + self.seq_classif_dropout = seq_classif_dropout + super().__init__(**kwargs, pad_token_id=pad_token_id) + + +class DistilBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ] + ) + + +__all__ = ["DistilBertConfig", "DistilBertOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..b78050a01aebb24f84ad3abb7b2a0925653f0541 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_distilbert.py @@ -0,0 +1,1377 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in +part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert) +""" + +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import get_activation +from ...configuration_utils import PretrainedConfig +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ( + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + is_torch_greater_or_equal_than_2_2, + prune_linear_layer, +) +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_distilbert import DistilBertConfig + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "distilbert-base-uncased" +_CONFIG_FOR_DOC = "DistilBertConfig" + + +# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE # + + +def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(out, modifier_rank=0): + if torch.distributed.get_rank() == 0: + _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) + else: + _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out) + + +def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]) + out.requires_grad = False + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + + +class Embeddings(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim) + + self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12) + self.dropout = nn.Dropout(config.dropout) + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Parameters: + input_ids (torch.Tensor): + torch.tensor(bs, max_seq_length) The token ids to embed. + input_embeds (*optional*, torch.Tensor): + The pre-computed word embeddings. Can only be passed if the input ids are `None`. + + + Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type + embeddings) + """ + if input_ids is not None: + input_embeds = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) + + seq_length = input_embeds.size(1) + + # Setting the position-ids to the registered buffer in constructor, it helps + # when tracing the model without passing position-ids, solves + # isues similar to issue #5664 + if hasattr(self, "position_ids"): + position_ids = self.position_ids[:, :seq_length] + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) + + position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) + + embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim) + embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) + embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) + return embeddings + + +class MultiHeadSelfAttention(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + + self.n_heads = config.n_heads + self.dim = config.dim + self.dropout = nn.Dropout(p=config.attention_dropout) + self.is_causal = False + + # Have an even number of multi heads that divide the dimensions + if self.dim % self.n_heads != 0: + # Raise value errors for even multi-head attention nodes + raise ValueError(f"self.n_heads: {self.n_heads} must divide self.dim: {self.dim} evenly") + + self.q_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.k_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.v_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + self.out_lin = nn.Linear(in_features=config.dim, out_features=config.dim) + + self.pruned_heads: Set[int] = set() + self.attention_head_size = self.dim // self.n_heads + + def prune_heads(self, heads: List[int]): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.attention_head_size, self.pruned_heads + ) + # Prune linear layers + self.q_lin = prune_linear_layer(self.q_lin, index) + self.k_lin = prune_linear_layer(self.k_lin, index) + self.v_lin = prune_linear_layer(self.v_lin, index) + self.out_lin = prune_linear_layer(self.out_lin, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.dim = self.attention_head_size * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + bs, q_length, dim = query.size() + k_length = key.size(1) + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + # assert key.size() == value.size() + + dim_per_head = self.dim // self.n_heads + + mask_reshp = (bs, 1, 1, k_length) + + def shape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x: torch.Tensor) -> torch.Tensor: + """group heads""" + return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + + q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) + scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) + mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill( + mask, torch.tensor(torch.finfo(scores.dtype).min) + ) # (bs, n_heads, q_length, k_length) + + weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) + weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) + context = unshape(context) # (bs, q_length, dim) + context = self.out_lin(context) # (bs, q_length, dim) + + if output_attentions: + return (context, weights) + else: + return (context,) + + +class DistilBertFlashAttention2(MultiHeadSelfAttention): + """ + DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + attn_dropout = self.config.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + q_length, + dropout=attn_dropout, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + +class DistilBertSdpaAttention(MultiHeadSelfAttention): + def __init__(self, config: PretrainedConfig): + super().__init__(config=config) + self.dropout_prob = config.attention_dropout + self.require_contiguous_qkv = not is_torch_greater_or_equal_than_2_2 + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + if output_attentions or head_mask is not None: + logger.warning_once( + "DistilBertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support" + " `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying" + " the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be" + ' removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + query, + key, + value, + mask, + head_mask, + output_attentions, + ) + + batch_size, _, _ = query.size() + dim_per_head = self.dim // self.n_heads + + def shape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head).transpose(1, 2) + + def unshape(x: torch.Tensor) -> torch.Tensor: + """group heads""" + return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and q.device.type == "cuda" and mask is not None: + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=False, + ) + + attn_output = unshape(attn_output) + attn_output = self.out_lin(attn_output) + + return (attn_output,) + + +class FFN(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dropout = nn.Dropout(p=config.dropout) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.lin1 = nn.Linear(in_features=config.dim, out_features=config.hidden_dim) + self.lin2 = nn.Linear(in_features=config.hidden_dim, out_features=config.dim) + self.activation = get_activation(config.activation) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input) + + def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: + x = self.lin1(input) + x = self.activation(x) + x = self.lin2(x) + x = self.dropout(x) + return x + + +DISTILBERT_ATTENTION_CLASSES = { + "eager": MultiHeadSelfAttention, + "flash_attention_2": DistilBertFlashAttention2, + "sdpa": DistilBertSdpaAttention, +} + + +class TransformerBlock(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + + # Have an even number of Configure multi-heads + if config.dim % config.n_heads != 0: + raise ValueError(f"config.n_heads {config.n_heads} must divide config.dim {config.dim} evenly") + + self.attention = DISTILBERT_ATTENTION_CLASSES[config._attn_implementation](config) + self.sa_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + self.ffn = FFN(config) + self.output_layer_norm = nn.LayerNorm(normalized_shape=config.dim, eps=1e-12) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) + attn_mask: torch.tensor(bs, seq_length) + + Returns: + sa_weights: torch.tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: + torch.tensor(bs, seq_length, dim) The output of the transformer block contextualization. + """ + # Self-Attention + sa_output = self.attention( + query=x, + key=x, + value=x, + mask=attn_mask, + head_mask=head_mask, + output_attentions=output_attentions, + ) + if output_attentions: + sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) + else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples + if type(sa_output) is not tuple: + raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type") + + sa_output = sa_output[0] + sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) + + # Feed Forward Network + ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) + ffn_output: torch.Tensor = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) + + output = (ffn_output,) + if output_attentions: + output = (sa_weights,) + output + return output + + +class Transformer(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.n_layers = config.n_layers + self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: # docstyle-ignore + """ + Parameters: + x: torch.tensor(bs, seq_length, dim) Input sequence embedded. + attn_mask: torch.tensor(bs, seq_length) Attention mask on the sequence. + + Returns: + hidden_state: torch.tensor(bs, seq_length, dim) Sequence of hidden states in the last (top) + layer all_hidden_states: Tuple[torch.tensor(bs, seq_length, dim)] + Tuple of length n_layers with the hidden states from each layer. + Optional: only if output_hidden_states=True + all_attentions: Tuple[torch.tensor(bs, n_heads, seq_length, seq_length)] + Tuple of length n_layers with the attention weights from each layer + Optional: only if output_attentions=True + """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_state = x + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) + + hidden_state = layer_outputs[-1] + + if output_attentions: + if len(layer_outputs) != 2: + raise ValueError(f"The length of the layer_outputs should be 2, but it is {len(layer_outputs)}") + + attentions = layer_outputs[0] + all_attentions = all_attentions + (attentions,) + else: + if len(layer_outputs) != 1: + raise ValueError(f"The length of the layer_outputs should be 1, but it is {len(layer_outputs)}") + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # +class DistilBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DistilBertConfig + load_tf_weights = None + base_model_prefix = "distilbert" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight + ) + + +DISTILBERT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DISTILBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.", + DISTILBERT_START_DOCSTRING, +) +class DistilBertModel(DistilBertPreTrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + + self.embeddings = Embeddings(config) # Embeddings + self.transformer = Transformer(config) # Encoder + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.embeddings.position_embeddings + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings + + # no resizing needs to be done if the length stays the same + if num_position_embeds_diff == 0: + return + + logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...") + self.config.max_position_embeddings = new_num_position_embeddings + + old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone() + + self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim) + + if self.config.sinusoidal_pos_embds: + create_sinusoidal_embeddings( + n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight + ) + else: + with torch.no_grad(): + if num_position_embeds_diff > 0: + self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter( + old_position_embeddings_weight + ) + else: + self.embeddings.position_embeddings.weight = nn.Parameter( + old_position_embeddings_weight[:num_position_embeds_diff] + ) + # move position_embeddings to correct device + self.embeddings.position_embeddings.to(self.device) + + def get_input_embeddings(self) -> nn.Embedding: + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings: nn.Embedding): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[List[int]]]): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.transformer.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + head_mask_is_none = head_mask is None + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) + + if self._use_flash_attention_2: + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) # (bs, seq_length) + + if self._use_sdpa and head_mask_is_none and not output_attentions: + attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embeddings.dtype, tgt_len=input_shape[1] + ) + + return self.transformer( + x=embeddings, + attn_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings( + """DistilBert Model with a `masked language modeling` head on top.""", + DISTILBERT_START_DOCSTRING, +) +class DistilBertForMaskedLM(DistilBertPreTrainedModel): + _tied_weights_keys = ["vocab_projector.weight"] + + def __init__(self, config: PretrainedConfig): + super().__init__(config) + + self.activation = get_activation(config.activation) + + self.distilbert = DistilBertModel(config) + self.vocab_transform = nn.Linear(config.dim, config.dim) + self.vocab_layer_norm = nn.LayerNorm(config.dim, eps=1e-12) + self.vocab_projector = nn.Linear(config.dim, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + self.mlm_loss_fct = nn.CrossEntropyLoss() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + def get_output_embeddings(self) -> nn.Module: + return self.vocab_projector + + def set_output_embeddings(self, new_embeddings: nn.Module): + self.vocab_projector = new_embeddings + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MaskedLMOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + dlbrt_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = dlbrt_output[0] # (bs, seq_length, dim) + prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) + prediction_logits = self.activation(prediction_logits) # (bs, seq_length, dim) + prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) + prediction_logits = self.vocab_projector(prediction_logits) # (bs, seq_length, vocab_size) + + mlm_loss = None + if labels is not None: + mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (prediction_logits,) + dlbrt_output[1:] + return ((mlm_loss,) + output) if mlm_loss is not None else output + + return MaskedLMOutput( + loss=mlm_loss, + logits=prediction_logits, + hidden_states=dlbrt_output.hidden_states, + attentions=dlbrt_output.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class DistilBertForSequenceClassification(DistilBertPreTrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.distilbert = DistilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, config.num_labels) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs, dim) + pooled_output = self.dropout(pooled_output) # (bs, dim) + logits = self.classifier(pooled_output) # (bs, num_labels) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DISTILBERT_START_DOCSTRING, +) +class DistilBertForQuestionAnswering(DistilBertPreTrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + + self.distilbert = DistilBertModel(config) + self.qa_outputs = nn.Linear(config.dim, config.num_labels) + if config.num_labels != 2: + raise ValueError(f"config.num_labels should be 2, but it is {config.num_labels}") + + self.dropout = nn.Dropout(config.qa_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[QuestionAnsweringModelOutput, Tuple[torch.Tensor, ...]]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = distilbert_output[0] # (bs, max_query_len, dim) + + hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim) + logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len) + end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + distilbert_output[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class DistilBertForTokenClassification(DistilBertPreTrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.distilbert = DistilBertModel(config) + self.dropout = nn.Dropout(config.dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`): + The number of new position embedding matrix. If position embeddings are learned, increasing the size + will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the + end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the + size will add correct vectors at the end following the position encoding algorithm, whereas reducing + the size will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[TokenClassifierOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.distilbert( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class DistilBertForMultipleChoice(DistilBertPreTrainedModel): + def __init__(self, config: PretrainedConfig): + super().__init__(config) + + self.distilbert = DistilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, 1) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + # Initialize weights and apply final processing + self.post_init() + + def get_position_embeddings(self) -> nn.Embedding: + """ + Returns the position embeddings + """ + return self.distilbert.get_position_embeddings() + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`. + + Arguments: + new_num_position_embeddings (`int`) + The number of new position embeddings. If position embeddings are learned, increasing the size will add + newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If + position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will + add correct vectors at the end following the position encoding algorithm, whereas reducing the size + will remove vectors from the end. + """ + self.distilbert.resize_position_embeddings(new_num_position_embeddings) + + @add_start_docstrings_to_model_forward( + DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[MultipleChoiceModelOutput, Tuple[torch.Tensor, ...]]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, DistilBertForMultipleChoice + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased") + >>> model = DistilBertForMultipleChoice.from_pretrained("distilbert-base-cased") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 + + >>> encoding = tokenizer([[prompt, choice0], [prompt, choice1]], return_tensors="pt", padding=True) + >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels) # batch size is 1 + + >>> # the linear classifier still needs to be trained + >>> loss = outputs.loss + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.distilbert( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] # (bs * num_choices, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs * num_choices, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs * num_choices, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs * num_choices, dim) + pooled_output = self.dropout(pooled_output) # (bs * num_choices, dim) + logits = self.classifier(pooled_output) # (bs * num_choices, 1) + + reshaped_logits = logits.view(-1, num_choices) # (bs, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "DistilBertForMaskedLM", + "DistilBertForMultipleChoice", + "DistilBertForQuestionAnswering", + "DistilBertForSequenceClassification", + "DistilBertForTokenClassification", + "DistilBertModel", + "DistilBertPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..1f2b6ac96ab63590dba11b9535042d05edb58400 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_flax_distilbert.py @@ -0,0 +1,906 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_distilbert import DistilBertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "distilbert-base-uncased" +_CONFIG_FOR_DOC = "DistilBertConfig" + + +FLAX_DISTILBERT_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a + [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as + a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and + behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DISTILBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def get_angles(pos, i, d_model): + angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) + return pos * angle_rates + + +def positional_encoding(position, d_model): + # create the sinusoidal pattern for the positional encoding + angle_rads = get_angles(np.arange(position)[:, np.newaxis], np.arange(d_model)[np.newaxis, :], d_model) + + # apply sin to even indices in the array; 2i + angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) + + # apply cos to odd indices in the array; 2i+1 + angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) + + pos_encoding = angle_rads[np.newaxis, ...] + + return jnp.array(pos_encoding) + + +class FlaxEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.dim, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + if not self.config.sinusoidal_pos_embds: + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.dim, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + else: + self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim) + self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.dropout) + + def __call__(self, input_ids, deterministic: bool = True): + # Embed + batch_size, seq_length = input_ids.shape + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + if not self.config.sinusoidal_pos_embds: + position_ids = jnp.arange(seq_length).astype("i4") + position_ids = jnp.broadcast_to(position_ids, shape=(batch_size, seq_length)) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + else: + position_embeds = self.pos_encoding[:, :seq_length, :] + # explicitly cast the positions here, since self.embed_positions are not registered as parameters + position_embeds = position_embeds.astype(inputs_embeds.dtype) + + # Sum all embeddings + hidden_states = inputs_embeds + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxMultiHeadSelfAttention(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.n_heads = self.config.n_heads + self.dim = self.config.dim + self.dropout = nn.Dropout(rate=self.config.attention_dropout) + + if not (self.dim % self.n_heads == 0): + raise ValueError(f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}") + + self.q_lin = nn.Dense( + self.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.k_lin = nn.Dense( + self.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.v_lin = nn.Dense( + self.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.out_lin = nn.Dense( + self.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + query, + key, + value, + mask, + deterministic: bool = True, + output_attentions: bool = False, + ): + bs, q_len, dim = query.shape + k_len = key.shape[1] + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + # assert key.size() == value.size() + + dim_per_head = self.dim // self.n_heads + + mask_reshp = (bs, 1, 1, k_len) + + def shape(x): + """separate heads""" + return x.reshape(bs, -1, self.n_heads, dim_per_head).transpose(0, 2, 1, 3) + + def unshape(x): + """group heads""" + return x.transpose(0, 2, 1, 3).reshape(bs, -1, self.n_heads * dim_per_head) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_len, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_len, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_len, dim_per_head) + + q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_len, dim_per_head) + scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) # (bs, n_heads, q_len, k_len) + mask = jnp.reshape(mask, mask_reshp) + + mask = mask.astype(scores.dtype) + scores = scores - 1e30 * (1.0 - mask) + + weights = nn.softmax(scores, axis=-1) # (bs, n_heads, q_len, k_len) + weights = self.dropout(weights, deterministic=deterministic) + + context = jnp.matmul(weights, v) # (bs, n_heads, q_len, dim_per_head) + context = unshape(context) # (bs, q_len, dim) + context = self.out_lin(context) # (bs, q_len, dim) + + if output_attentions: + return (context, weights) + else: + return (context,) + + +class FlaxFFN(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dropout = nn.Dropout(rate=self.config.dropout) + self.chunk_size_feed_forward = self.config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.lin1 = nn.Dense( + self.config.hidden_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.lin2 = nn.Dense( + self.config.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + self.activation = ACT2FN[self.config.activation] + + def __call__(self, hidden_states, deterministic: bool = True): + hidden_states = self.lin1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.lin2(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxTransformerBlock(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + assert self.config.dim % self.config.n_heads == 0, ( + f"Hidden size {self.config.dim} not dividable by number of heads {self.config.n_heads}" + ) + + self.attention = FlaxMultiHeadSelfAttention(self.config, dtype=self.dtype) + self.sa_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) + + self.ffn = FlaxFFN(self.config, dtype=self.dtype) + self.output_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attn_mask, + output_attentions: bool = False, + deterministic: bool = True, + ): + # Self-Attention + sa_output = self.attention( + query=hidden_states, + key=hidden_states, + value=hidden_states, + mask=attn_mask, + output_attentions=output_attentions, + deterministic=deterministic, + ) + if output_attentions: + sa_output, sa_weights = sa_output + else: + assert type(sa_output) is tuple + sa_output = sa_output[0] + sa_output = self.sa_layer_norm(sa_output + hidden_states) + + # Feed Forward Network + ffn_output = self.ffn(sa_output, deterministic=deterministic) + ffn_output = self.output_layer_norm(ffn_output + sa_output) + output = (ffn_output,) + if output_attentions: + output = (sa_weights,) + output + return output + + +class FlaxTransformer(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxTransformerBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.n_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + return_dict: bool = False, + ): + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for layer_module in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=hidden_states, + attn_mask=attention_mask, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = layer_outputs[-1] + + if output_attentions: + assert len(layer_outputs) == 2 + attentions = layer_outputs[0] + all_attentions = all_attentions + (attentions,) + else: + assert len(layer_outputs) == 1 + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_attentions, all_hidden_states] if v is not None) + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxTransformerEncoder(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxTransformer(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + output_attentions: bool = False, + output_hidden_states: bool = False, + deterministic: bool = True, + return_dict: bool = False, + ): + return self.layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + return_dict=return_dict, + ) + + +class FlaxDistilBertLMDecoder(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, inputs, kernel): + inputs = jnp.asarray(inputs, self.dtype) + kernel = jnp.asarray(kernel, self.dtype) + y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ()))) + bias = jnp.asarray(self.bias, self.dtype) + y = y + bias + return y + + +class FlaxDistilBertPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DistilBertConfig + base_model_prefix = "distilbert" + module_class: nn.Module = None + + def __init__( + self, + config: DistilBertConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + head_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxDistilBertModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embeddings = FlaxEmbeddings(self.config, dtype=self.dtype) + self.transformer = FlaxTransformerEncoder(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + input_embeds = self.embeddings(input_ids, deterministic=deterministic) + return self.transformer( + hidden_states=input_embeds, + attention_mask=attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings( + "The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top.", + FLAX_DISTILBERT_START_DOCSTRING, +) +class FlaxDistilBertModel(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertModule + + +append_call_sample_docstring(FlaxDistilBertModel, _CHECKPOINT_FOR_DOC, None, _CONFIG_FOR_DOC) + + +class FlaxDistilBertForMaskedLMModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype) + self.vocab_transform = nn.Dense( + self.config.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype) + if self.config.tie_word_embeddings: + self.vocab_projector = FlaxDistilBertLMDecoder( + self.config, + dtype=self.dtype, + ) + else: + self.vocab_projector = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + dlbrt_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + deterministic=deterministic, + return_dict=return_dict, + ) + hidden_states = dlbrt_output[0] + prediction_logits = self.vocab_transform(hidden_states) + prediction_logits = ACT2FN[self.config.activation](prediction_logits) + prediction_logits = self.vocab_layer_norm(prediction_logits) + + if self.config.tie_word_embeddings: + shared_embedding = self.distilbert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + prediction_logits = self.vocab_projector(prediction_logits, shared_embedding.T) + else: + prediction_logits = self.vocab_projector(prediction_logits) + + if not return_dict: + output = (prediction_logits,) + dlbrt_output[1:] + return output + + return FlaxMaskedLMOutput( + logits=prediction_logits, + hidden_states=dlbrt_output.hidden_states, + attentions=dlbrt_output.attentions, + ) + + +@add_start_docstrings("""DistilBert Model with a `language modeling` head on top.""", FLAX_DISTILBERT_START_DOCSTRING) +class FlaxDistilBertForMaskedLM(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertForMaskedLMModule + + +append_call_sample_docstring(FlaxDistilBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) + + +class FlaxDistilBertForSequenceClassificationModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) + self.pre_classifier = nn.Dense( + self.config.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Model + distilbert_output = self.distilbert( + input_ids, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = ACT2FN["relu"](pooled_output) + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) # (bs, dim) + + if not return_dict: + return (logits,) + distilbert_output[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + FLAX_DISTILBERT_START_DOCSTRING, +) +class FlaxDistilBertForSequenceClassification(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxDistilBertForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxDistilBertForMultipleChoiceModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) + self.pre_classifier = nn.Dense( + self.config.dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.dropout = nn.Dropout(rate=self.config.seq_classif_dropout) + self.classifier = nn.Dense( + 1, + dtype=self.dtype, + ) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + + # Model + outputs = self.distilbert( + input_ids, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_state = outputs[0] + pooled_output = hidden_state[:, 0] + pooled_output = self.pre_classifier(pooled_output) + pooled_output = ACT2FN["relu"](pooled_output) + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + FLAX_DISTILBERT_START_DOCSTRING, +) +class FlaxDistilBertForMultipleChoice(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertForMultipleChoiceModule + + +overwrite_call_docstring( + FlaxDistilBertForMultipleChoice, DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxDistilBertForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxDistilBertForTokenClassificationModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # Model + outputs = self.distilbert( + input_ids, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + FLAX_DISTILBERT_START_DOCSTRING, +) +class FlaxDistilBertForTokenClassification(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertForTokenClassificationModule + + +append_call_sample_docstring( + FlaxDistilBertForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxDistilBertForQuestionAnsweringModule(nn.Module): + config: DistilBertConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.distilbert = FlaxDistilBertModule(config=self.config, dtype=self.dtype) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + assert self.config.num_labels == 2 + self.dropout = nn.Dropout(rate=self.config.qa_dropout) + + def __call__( + self, + input_ids, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Model + distilbert_output = self.distilbert( + input_ids, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = distilbert_output[0] + + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + distilbert_output[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + +@add_start_docstrings( + """ + DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + FLAX_DISTILBERT_START_DOCSTRING, +) +class FlaxDistilBertForQuestionAnswering(FlaxDistilBertPreTrainedModel): + module_class = FlaxDistilBertForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxDistilBertForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxDistilBertForMaskedLM", + "FlaxDistilBertForMultipleChoice", + "FlaxDistilBertForQuestionAnswering", + "FlaxDistilBertForSequenceClassification", + "FlaxDistilBertForTokenClassification", + "FlaxDistilBertModel", + "FlaxDistilBertPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ee2f84835d41585f6133148a4596a1253e8604 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/modeling_tf_distilbert.py @@ -0,0 +1,1147 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TF 2.0 DistilBERT model +""" + +from __future__ import annotations + +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutput, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_distilbert import DistilBertConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "distilbert-base-uncased" +_CONFIG_FOR_DOC = "DistilBertConfig" + + +class TFEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.dim = config.dim + self.initializer_range = config.initializer_range + self.max_position_embeddings = config.max_position_embeddings + self.LayerNorm = keras.layers.LayerNormalization(epsilon=1e-12, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.dropout) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.dim], + initializer=get_initializer(initializer_range=self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.dim], + initializer=get_initializer(initializer_range=self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.dim]) + + def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if position_ids is None: + position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + final_embeddings = inputs_embeds + position_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +class TFMultiHeadSelfAttention(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.n_heads = config.n_heads + self.dim = config.dim + self.dropout = keras.layers.Dropout(config.attention_dropout) + self.output_attentions = config.output_attentions + + assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}" + + self.q_lin = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="q_lin" + ) + self.k_lin = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="k_lin" + ) + self.v_lin = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="v_lin" + ) + self.out_lin = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="out_lin" + ) + + self.pruned_heads = set() + self.config = config + + def prune_heads(self, heads): + raise NotImplementedError + + def call(self, query, key, value, mask, head_mask, output_attentions, training=False): + """ + Parameters: + query: tf.Tensor(bs, seq_length, dim) + key: tf.Tensor(bs, seq_length, dim) + value: tf.Tensor(bs, seq_length, dim) + mask: tf.Tensor(bs, seq_length) + + Returns: + weights: tf.Tensor(bs, n_heads, seq_length, seq_length) Attention weights context: tf.Tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + bs, q_length, dim = shape_list(query) + k_length = shape_list(key)[1] + # assert dim == self.dim, f'Dimensions do not match: {dim} input vs {self.dim} configured' + # assert key.size() == value.size() + dim_per_head = int(self.dim / self.n_heads) + dim_per_head = tf.cast(dim_per_head, dtype=tf.int32) + mask_reshape = [bs, 1, 1, k_length] + + def shape(x): + """separate heads""" + return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3)) + + def unshape(x): + """group heads""" + return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head)) + + q = shape(self.q_lin(query)) # (bs, n_heads, q_length, dim_per_head) + k = shape(self.k_lin(key)) # (bs, n_heads, k_length, dim_per_head) + v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) + q = tf.cast(q, dtype=tf.float32) + q = tf.multiply(q, tf.math.rsqrt(tf.cast(dim_per_head, dtype=tf.float32))) + k = tf.cast(k, dtype=q.dtype) + scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, q_length, k_length) + mask = tf.reshape(mask, mask_reshape) # (bs, n_heads, qlen, klen) + # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, q_length, k_length) + + mask = tf.cast(mask, dtype=scores.dtype) + scores = scores - 1e30 * (1.0 - mask) + weights = stable_softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) + weights = self.dropout(weights, training=training) # (bs, n_heads, qlen, klen) + + # Mask heads if we want to + if head_mask is not None: + weights = weights * head_mask + + context = tf.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) + context = unshape(context) # (bs, q_length, dim) + context = self.out_lin(context) # (bs, q_length, dim) + + if output_attentions: + return (context, weights) + else: + return (context,) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "q_lin", None) is not None: + with tf.name_scope(self.q_lin.name): + self.q_lin.build([None, None, self.config.dim]) + if getattr(self, "k_lin", None) is not None: + with tf.name_scope(self.k_lin.name): + self.k_lin.build([None, None, self.config.dim]) + if getattr(self, "v_lin", None) is not None: + with tf.name_scope(self.v_lin.name): + self.v_lin.build([None, None, self.config.dim]) + if getattr(self, "out_lin", None) is not None: + with tf.name_scope(self.out_lin.name): + self.out_lin.build([None, None, self.config.dim]) + + +class TFFFN(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dropout = keras.layers.Dropout(config.dropout) + self.lin1 = keras.layers.Dense( + config.hidden_dim, kernel_initializer=get_initializer(config.initializer_range), name="lin1" + ) + self.lin2 = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="lin2" + ) + self.activation = get_tf_activation(config.activation) + self.config = config + + def call(self, input, training=False): + x = self.lin1(input) + x = self.activation(x) + x = self.lin2(x) + x = self.dropout(x, training=training) + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "lin1", None) is not None: + with tf.name_scope(self.lin1.name): + self.lin1.build([None, None, self.config.dim]) + if getattr(self, "lin2", None) is not None: + with tf.name_scope(self.lin2.name): + self.lin2.build([None, None, self.config.hidden_dim]) + + +class TFTransformerBlock(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.n_heads = config.n_heads + self.dim = config.dim + self.hidden_dim = config.hidden_dim + self.dropout = keras.layers.Dropout(config.dropout) + self.activation = config.activation + self.output_attentions = config.output_attentions + + assert config.dim % config.n_heads == 0, ( + f"Hidden size {config.dim} not dividable by number of heads {config.n_heads}" + ) + + self.attention = TFMultiHeadSelfAttention(config, name="attention") + self.sa_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="sa_layer_norm") + + self.ffn = TFFFN(config, name="ffn") + self.output_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm") + self.config = config + + def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None + """ + Parameters: + x: tf.Tensor(bs, seq_length, dim) + attn_mask: tf.Tensor(bs, seq_length) + + Outputs: sa_weights: tf.Tensor(bs, n_heads, seq_length, seq_length) The attention weights ffn_output: + tf.Tensor(bs, seq_length, dim) The output of the transformer block contextualization. + """ + # Self-Attention + sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training) + if output_attentions: + sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length) + else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples + # assert type(sa_output) == tuple + sa_output = sa_output[0] + sa_output = self.sa_layer_norm(sa_output + x) # (bs, seq_length, dim) + + # Feed Forward Network + ffn_output = self.ffn(sa_output, training=training) # (bs, seq_length, dim) + ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim) + + output = (ffn_output,) + if output_attentions: + output = (sa_weights,) + output + return output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "sa_layer_norm", None) is not None: + with tf.name_scope(self.sa_layer_norm.name): + self.sa_layer_norm.build([None, None, self.config.dim]) + if getattr(self, "ffn", None) is not None: + with tf.name_scope(self.ffn.name): + self.ffn.build(None) + if getattr(self, "output_layer_norm", None) is not None: + with tf.name_scope(self.output_layer_norm.name): + self.output_layer_norm.build([None, None, self.config.dim]) + + +class TFTransformer(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.n_layers = config.n_layers + self.output_hidden_states = config.output_hidden_states + self.output_attentions = config.output_attentions + + self.layer = [TFTransformerBlock(config, name=f"layer_._{i}") for i in range(config.n_layers)] + + def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, return_dict, training=False): + # docstyle-ignore + """ + Parameters: + x: tf.Tensor(bs, seq_length, dim) Input sequence embedded. + attn_mask: tf.Tensor(bs, seq_length) Attention mask on the sequence. + + Returns: + hidden_state: tf.Tensor(bs, seq_length, dim) + Sequence of hidden states in the last (top) layer + all_hidden_states: Tuple[tf.Tensor(bs, seq_length, dim)] + Tuple of length n_layers with the hidden states from each layer. + Optional: only if output_hidden_states=True + all_attentions: Tuple[tf.Tensor(bs, n_heads, seq_length, seq_length)] + Tuple of length n_layers with the attention weights from each layer + Optional: only if output_attentions=True + """ + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_state = x + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training) + hidden_state = layer_outputs[-1] + + if output_attentions: + assert len(layer_outputs) == 2 + attentions = layer_outputs[0] + all_attentions = all_attentions + (attentions,) + else: + assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1" + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_state,) + + if not return_dict: + return tuple(v for v in [hidden_state, all_hidden_states, all_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_state, hidden_states=all_hidden_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFDistilBertMainLayer(keras.layers.Layer): + config_class = DistilBertConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + + self.embeddings = TFEmbeddings(config, name="embeddings") # Embeddings + self.transformer = TFTransformer(config, name="transformer") # Encoder + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.weight = value + self.embeddings.vocab_size = value.shape[0] + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + @unpack_inputs + def call( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + training=False, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.ones(input_shape) # (bs, seq_length) + + attention_mask = tf.cast(attention_mask, dtype=tf.float32) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.num_hidden_layers + + embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim) + tfmr_output = self.transformer( + embedding_output, + attention_mask, + head_mask, + output_attentions, + output_hidden_states, + return_dict, + training=training, + ) + + return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "transformer", None) is not None: + with tf.name_scope(self.transformer.name): + self.transformer.build(None) + + +# INTERFACE FOR ENCODER AND TASK SPECIFIC MODEL # +class TFDistilBertPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DistilBertConfig + base_model_prefix = "distilbert" + + +DISTILBERT_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`DistilBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DISTILBERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare DistilBERT encoder/transformer outputting raw hidden-states without any specific head on top.", + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertModel(TFDistilBertPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") # Embeddings + + @unpack_inputs + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + outputs = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + + +class TFDistilBertLMHead(keras.layers.Layer): + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.dim = config.dim + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.input_embeddings = input_embeddings + + def build(self, input_shape): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + super().build(input_shape) + + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.dim]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings( + """DistilBert Model with a `masked language modeling` head on top.""", + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertForMaskedLM(TFDistilBertPreTrainedModel, TFMaskedLanguageModelingLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.config = config + + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.vocab_transform = keras.layers.Dense( + config.dim, kernel_initializer=get_initializer(config.initializer_range), name="vocab_transform" + ) + self.act = get_tf_activation(config.activation) + self.vocab_layer_norm = keras.layers.LayerNormalization(epsilon=1e-12, name="vocab_layer_norm") + self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector") + + def get_lm_head(self): + return self.vocab_projector + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.vocab_projector.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = distilbert_output[0] # (bs, seq_length, dim) + prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim) + prediction_logits = self.act(prediction_logits) # (bs, seq_length, dim) + prediction_logits = self.vocab_layer_norm(prediction_logits) # (bs, seq_length, dim) + prediction_logits = self.vocab_projector(prediction_logits) + + loss = None if labels is None else self.hf_compute_loss(labels, prediction_logits) + + if not return_dict: + output = (prediction_logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + if getattr(self, "vocab_transform", None) is not None: + with tf.name_scope(self.vocab_transform.name): + self.vocab_transform.build([None, None, self.config.dim]) + if getattr(self, "vocab_layer_norm", None) is not None: + with tf.name_scope(self.vocab_layer_norm.name): + self.vocab_layer_norm.build([None, None, self.config.dim]) + if getattr(self, "vocab_projector", None) is not None: + with tf.name_scope(self.vocab_projector.name): + self.vocab_projector.build(None) + + +@add_start_docstrings( + """ + DistilBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.pre_classifier = keras.layers.Dense( + config.dim, + kernel_initializer=get_initializer(config.initializer_range), + activation="relu", + name="pre_classifier", + ) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.dropout = keras.layers.Dropout(config.seq_classif_dropout) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) + logits = self.classifier(pooled_output) # (bs, dim) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + if getattr(self, "pre_classifier", None) is not None: + with tf.name_scope(self.pre_classifier.name): + self.pre_classifier.build([None, None, self.config.dim]) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.dim]) + + +@add_start_docstrings( + """ + DistilBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + for Named-Entity-Recognition (NER) tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.dropout = keras.layers.Dropout(config.dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + DistilBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.dropout = keras.layers.Dropout(config.seq_classif_dropout) + self.pre_classifier = keras.layers.Dense( + config.dim, + kernel_initializer=get_initializer(config.initializer_range), + activation="relu", + name="pre_classifier", + ) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward( + DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + distilbert_output = self.distilbert( + flat_input_ids, + flat_attention_mask, + head_mask, + flat_inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_state = distilbert_output[0] # (bs, seq_len, dim) + pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.pre_classifier(pooled_output) # (bs, dim) + pooled_output = self.dropout(pooled_output, training=training) # (bs, dim) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + if getattr(self, "pre_classifier", None) is not None: + with tf.name_scope(self.pre_classifier.name): + self.pre_classifier.build([None, None, self.config.dim]) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.dim]) + + +@add_start_docstrings( + """ + DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DISTILBERT_START_DOCSTRING, +) +class TFDistilBertForQuestionAnswering(TFDistilBertPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.distilbert = TFDistilBertMainLayer(config, name="distilbert") + self.qa_outputs = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + assert config.num_labels == 2, f"Incorrect number of labels {config.num_labels} instead of 2" + self.dropout = keras.layers.Dropout(config.qa_dropout) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + distilbert_output = self.distilbert( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = distilbert_output[0] # (bs, max_query_len, dim) + hidden_states = self.dropout(hidden_states, training=training) # (bs, max_query_len, dim) + logits = self.qa_outputs(hidden_states) # (bs, max_query_len, 2) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + distilbert_output[1:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=distilbert_output.hidden_states, + attentions=distilbert_output.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "distilbert", None) is not None: + with tf.name_scope(self.distilbert.name): + self.distilbert.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.dim]) + + +__all__ = [ + "TFDistilBertForMaskedLM", + "TFDistilBertForMultipleChoice", + "TFDistilBertForQuestionAnswering", + "TFDistilBertForSequenceClassification", + "TFDistilBertForTokenClassification", + "TFDistilBertMainLayer", + "TFDistilBertModel", + "TFDistilBertPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..c894211a2e0acf2bd82858186e88f7e3e99f672e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert.py @@ -0,0 +1,522 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DistilBERT.""" + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class DistilBertTokenizer(PreTrainedTokenizer): + r""" + Construct a DistilBERT tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + clean_up_tokenization_spaces=True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.do_lower_case + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.vocab_size + def vocab_size(self): + return len(self.vocab) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_vocab + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_token_to_id + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._convert_id_to_token + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + # Copied from transformers.models.bert.tokenization_bert.BertTokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["DistilBertTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..d3829763d5e7ab8e2a338c53a0f7dd50c3e4b737 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/distilbert/tokenization_distilbert_fast.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DistilBERT.""" + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import logging +from .tokenization_distilbert import DistilBertTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class DistilBertTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" DistilBERT tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = DistilBertTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + # Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["DistilBertTokenizerFast"] diff --git a/docs/transformers/build/lib/transformers/models/dit/__init__.py b/docs/transformers/build/lib/transformers/models/dit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..40c5b22e3b9a2dd2037660902febd8069ca41a7d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dit/convert_dit_unilm_to_pytorch.py @@ -0,0 +1,230 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DiT checkpoints from the unilm repository.""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import BeitConfig, BeitForImageClassification, BeitForMaskedImageModeling, BeitImageProcessor +from transformers.image_utils import PILImageResampling +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, has_lm_head=False, is_semantic=False): + prefix = "backbone." if is_semantic else "" + + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append( + (f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight") + ) + rename_keys.append( + (f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias") + ) + rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + (f"{prefix}cls_token", "beit.embeddings.cls_token"), + (f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"), + (f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"), + (f"{prefix}pos_embed", "beit.embeddings.position_embeddings"), + ] + ) + + if has_lm_head: + # mask token + layernorm + rename_keys.extend( + [ + ("mask_token", "beit.embeddings.mask_token"), + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + else: + # layernorm + classification head + rename_keys.extend( + [ + ("fc_norm.weight", "beit.pooler.layernorm.weight"), + ("fc_norm.bias", "beit.pooler.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False): + for i in range(config.num_hidden_layers): + prefix = "backbone." if is_semantic else "" + # queries, keys and values + in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight") + q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias") + v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias") + + state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias + state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias + + # gamma_1 and gamma_2 + # we call them lambda because otherwise they are renamed when using .from_pretrained + gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1") + gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2") + + state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1 + state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2 + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dit_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub=False): + """ + Copy/paste/tweak model's weights to our BEiT structure. + """ + + # define default BEiT configuration + has_lm_head = False if "rvlcdip" in checkpoint_url else True + config = BeitConfig(use_absolute_position_embeddings=True, use_mask_token=has_lm_head) + + # size of the architecture + if "large" in checkpoint_url or "dit-l" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + + # labels + if "rvlcdip" in checkpoint_url: + config.num_labels = 16 + repo_id = "huggingface/label-files" + filename = "rvlcdip-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + # load state_dict of original model, remove and rename some keys + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"] + + rename_keys = create_rename_keys(config, has_lm_head=has_lm_head) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head) + + # load HuggingFace model + model = BeitForMaskedImageModeling(config) if has_lm_head else BeitForImageClassification(config) + model.eval() + model.load_state_dict(state_dict) + + # Check outputs on an image + image_processor = BeitImageProcessor( + size=config.image_size, resample=PILImageResampling.BILINEAR, do_center_crop=False + ) + image = prepare_img() + + encoding = image_processor(images=image, return_tensors="pt") + pixel_values = encoding["pixel_values"] + + outputs = model(pixel_values) + logits = outputs.logits + + # verify logits + expected_shape = [1, 16] if "rvlcdip" in checkpoint_url else [1, 196, 8192] + assert logits.shape == torch.Size(expected_shape), "Shape of logits not as expected" + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + if has_lm_head: + model_name = "dit-base" if "base" in checkpoint_url else "dit-large" + else: + model_name = "dit-base-finetuned-rvlcdip" if "dit-b" in checkpoint_url else "dit-large-finetuned-rvlcdip" + image_processor.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, model_name), + organization="nielsr", + commit_message="Add image processor", + use_temp_dir=True, + ) + model.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, model_name), + organization="nielsr", + commit_message="Add model", + use_temp_dir=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--checkpoint_url", + default="https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth", + type=str, + help="URL to the original PyTorch checkpoint (.pth file).", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + ) + args = parser.parse_args() + convert_dit_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/donut/__init__.py b/docs/transformers/build/lib/transformers/models/donut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..834c451f78fa0d4c5fe91f59719b6505c4c4e4e5 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_donut_swin import * + from .feature_extraction_donut import * + from .image_processing_donut import * + from .image_processing_donut_fast import * + from .modeling_donut_swin import * + from .processing_donut import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py b/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..9aac07dace7688273be0bdc57da0a12663c2fb5b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/configuration_donut_swin.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Donut Swin Transformer model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DonutSwinConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DonutSwinModel`]. It is used to instantiate a + Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Donut + [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 4): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + embed_dim (`int`, *optional*, defaults to 96): + Dimensionality of patch embedding. + depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`): + Depth of each layer in the Transformer encoder. + num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`): + Number of attention heads in each layer of the Transformer encoder. + window_size (`int`, *optional*, defaults to 7): + Size of windows. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Ratio of MLP hidden dimensionality to embedding dimensionality. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether or not a learnable bias should be added to the queries, keys and values. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + drop_path_rate (`float`, *optional*, defaults to 0.1): + Stochastic depth rate. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`, + `"selu"` and `"gelu_new"` are supported. + use_absolute_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to add absolute position embeddings to the patch embeddings. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + + Example: + + ```python + >>> from transformers import DonutSwinConfig, DonutSwinModel + + >>> # Initializing a Donut naver-clova-ix/donut-base style configuration + >>> configuration = DonutSwinConfig() + + >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration + >>> model = DonutSwinModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=False, + initializer_range=0.02, + layer_norm_eps=1e-5, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + + +__all__ = ["DonutSwinConfig"] diff --git a/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py b/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f14f6d08e31037389f448815242b388545fd15 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/convert_donut_to_pytorch.py @@ -0,0 +1,234 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Donut checkpoints using the original `donut-python` library. URL: https://github.com/clovaai/donut""" + +import argparse + +import torch +from datasets import load_dataset +from donut import DonutModel + +from transformers import ( + DonutImageProcessor, + DonutProcessor, + DonutSwinConfig, + DonutSwinModel, + MBartConfig, + MBartForCausalLM, + VisionEncoderDecoderModel, + XLMRobertaTokenizerFast, +) + + +def get_configs(model): + original_config = model.config + + encoder_config = DonutSwinConfig( + image_size=original_config.input_size, + patch_size=4, + depths=original_config.encoder_layer, + num_heads=[4, 8, 16, 32], + window_size=original_config.window_size, + embed_dim=128, + ) + decoder_config = MBartConfig( + is_decoder=True, + is_encoder_decoder=False, + add_cross_attention=True, + decoder_layers=original_config.decoder_layer, + max_position_embeddings=original_config.max_position_embeddings, + vocab_size=len( + model.decoder.tokenizer + ), # several special tokens are added to the vocab of XLMRobertaTokenizer, see repo on the hub (added_tokens.json) + scale_embedding=True, + add_final_layer_norm=True, + ) + + return encoder_config, decoder_config + + +def rename_key(name): + if "encoder.model" in name: + name = name.replace("encoder.model", "encoder") + if "decoder.model" in name: + name = name.replace("decoder.model", "decoder") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "embeddings.norm") + if name.startswith("encoder"): + if "layers" in name: + name = "encoder." + name + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name and "mask" not in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + + if name == "encoder.norm.weight": + name = "encoder.layernorm.weight" + if name == "encoder.norm.bias": + name = "encoder.layernorm.bias" + + return name + + +def convert_state_dict(orig_state_dict, model): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[3]) + block_num = int(key_split[5]) + dim = model.encoder.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size + + if "weight" in key: + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight" + ] = val[:dim, :] + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = ( + val[dim : dim * 2, :] + ) + orig_state_dict[ + f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight" + ] = val[-dim:, :] + else: + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = ( + val[:dim] + ) + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = ( + val[dim : dim * 2] + ) + orig_state_dict[f"encoder.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = ( + val[-dim:] + ) + elif "attn_mask" in key or key in ["encoder.model.norm.weight", "encoder.model.norm.bias"]: + # HuggingFace implementation doesn't use attn_mask buffer + # and model doesn't use final LayerNorms for the encoder + pass + else: + orig_state_dict[rename_key(key)] = val + + return orig_state_dict + + +def convert_donut_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False): + # load original model + original_model = DonutModel.from_pretrained(model_name).eval() + + # load HuggingFace model + encoder_config, decoder_config = get_configs(original_model) + encoder = DonutSwinModel(encoder_config) + decoder = MBartForCausalLM(decoder_config) + model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) + model.eval() + + state_dict = original_model.state_dict() + new_state_dict = convert_state_dict(state_dict, model) + model.load_state_dict(new_state_dict) + + # verify results on scanned document + dataset = load_dataset("hf-internal-testing/example-documents") # no-script + image = dataset["test"][0]["image"].convert("RGB") + + tokenizer = XLMRobertaTokenizerFast.from_pretrained(model_name, from_slow=True) + image_processor = DonutImageProcessor( + do_align_long_axis=original_model.config.align_long_axis, size=original_model.config.input_size[::-1] + ) + processor = DonutProcessor(image_processor, tokenizer) + pixel_values = processor(image, return_tensors="pt").pixel_values + + if model_name == "naver-clova-ix/donut-base-finetuned-docvqa": + task_prompt = "{user_input}" + question = "When is the coffee break?" + task_prompt = task_prompt.replace("{user_input}", question) + elif model_name == "naver-clova-ix/donut-base-finetuned-rvlcdip": + task_prompt = "" + elif model_name in [ + "naver-clova-ix/donut-base-finetuned-cord-v1", + "naver-clova-ix/donut-base-finetuned-cord-v1-2560", + ]: + task_prompt = "" + elif model_name == "naver-clova-ix/donut-base-finetuned-cord-v2": + task_prompt = "s_cord-v2>" + elif model_name == "naver-clova-ix/donut-base-finetuned-zhtrainticket": + task_prompt = "" + elif model_name in ["naver-clova-ix/donut-proto", "naver-clova-ix/donut-base"]: + # use a random prompt + task_prompt = "hello world" + else: + raise ValueError("Model name not supported") + prompt_tensors = original_model.decoder.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt")[ + "input_ids" + ] + + original_patch_embed = original_model.encoder.model.patch_embed(pixel_values) + patch_embeddings, _ = model.encoder.embeddings(pixel_values) + assert torch.allclose(original_patch_embed, patch_embeddings, atol=1e-3) + + # verify encoder hidden states + original_last_hidden_state = original_model.encoder(pixel_values) + last_hidden_state = model.encoder(pixel_values).last_hidden_state + assert torch.allclose(original_last_hidden_state, last_hidden_state, atol=1e-2) + + # verify decoder hidden states + original_logits = original_model(pixel_values, prompt_tensors, None).logits + logits = model(pixel_values, decoder_input_ids=prompt_tensors).logits + assert torch.allclose(original_logits, logits, atol=1e-3) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model") + processor.push_to_hub("nielsr/" + model_name.split("/")[-1], commit_message="Update model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="naver-clova-ix/donut-base-finetuned-docvqa", + required=False, + type=str, + help="Name of the original model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + required=False, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the converted model and processor to the 🤗 hub.", + ) + + args = parser.parse_args() + convert_donut_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py b/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py new file mode 100644 index 0000000000000000000000000000000000000000..e37a58ddd3055e040c6c29cbd5f5cc3c34270cbe --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/feature_extraction_donut.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for Donut.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_donut import DonutImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class DonutFeatureExtractor(DonutImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class DonutFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use DonutImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["DonutFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py new file mode 100644 index 0000000000000000000000000000000000000000..72d051859a70d2cd83c1fc5c538361534aa82e7d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut.py @@ -0,0 +1,477 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for Donut.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + pad, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import is_vision_available, requires + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + + +@requires(backends=("vision",)) +class DonutImageProcessor(BaseImageProcessor): + r""" + Constructs a Donut image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_thumbnail (`bool`, *optional*, defaults to `True`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `False`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `random_padding` is set to `True` in `preprocess`, each image is padded with a + random amont of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Image standard deviation. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_thumbnail: bool = True, + do_align_long_axis: bool = False, + do_pad: bool = True, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + size = size if size is not None else {"height": 2560, "width": 1920} + if isinstance(size, (tuple, list)): + # The previous feature extractor size parameter was in (width, height) format + size = size[::-1] + size = get_size_dict(size) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_thumbnail = do_thumbnail + self.do_align_long_axis = do_align_long_axis + self.do_pad = do_pad + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + + def align_long_axis( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`np.ndarray`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + `np.ndarray`: The aligned image. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(image) + + if input_data_format == ChannelDimension.LAST: + rot_axes = (0, 1) + elif input_data_format == ChannelDimension.FIRST: + rot_axes = (1, 2) + else: + raise ValueError(f"Unsupported data format: {input_data_format}") + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3, axes=rot_axes) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + return image + + def pad_image( + self, + image: np.ndarray, + size: Dict[str, int], + random_padding: bool = False, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Pad the image to the specified size. + + Args: + image (`np.ndarray`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = np.random.randint(low=0, high=delta_height + 1) + pad_left = np.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format, input_data_format=input_data_format) + + def pad(self, *args, **kwargs): + logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") + return self.pad_image(*args, **kwargs) + + def thumbnail( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`np.ndarray`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + output_height, output_width = size["height"], size["width"] + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return resize( + image, + size=(height, width), + resample=resample, + reducing_gap=2.0, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resizes `image` to `(height, width)` specified by `size` using the PIL library. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size) + shortest_edge = min(size["height"], size["width"]) + output_size = get_resize_output_image_size( + image, size=shortest_edge, default_to_square=False, input_data_format=input_data_format + ) + resized_image = resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + return resized_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: Optional[bool] = None, + do_align_long_axis: Optional[bool] = None, + do_pad: Optional[bool] = None, + random_padding: bool = False, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to min(size["height"], + size["width"]) with the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amont of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + random_padding (`bool`, *optional*, defaults to `self.random_padding`): + Whether to use random padding when padding the image. If `True`, each image in the batch with be padded + with a random amount of padding on each side up to the size of the largest image in the batch. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: defaults to the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + if isinstance(size, (tuple, list)): + # Previous feature extractor had size in (width, height) format + size = size[::-1] + size = get_size_dict(size) + resample = resample if resample is not None else self.resample + do_thumbnail = do_thumbnail if do_thumbnail is not None else self.do_thumbnail + do_align_long_axis = do_align_long_axis if do_align_long_axis is not None else self.do_align_long_axis + do_pad = do_pad if do_pad is not None else self.do_pad + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size, # There is no pad divisibility in this processor, but pad requires the size arg. + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_align_long_axis: + images = [self.align_long_axis(image, size=size, input_data_format=input_data_format) for image in images] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_thumbnail: + images = [self.thumbnail(image=image, size=size, input_data_format=input_data_format) for image in images] + + if do_pad: + images = [ + self.pad_image( + image=image, size=size, random_padding=random_padding, input_data_format=input_data_format + ) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["DonutImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..be83b5cc5c6bd1b83d23ff1df9a853fbe023a74e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/image_processing_donut_fast.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Donut.""" + +from typing import Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, + logging, +) + + +logger = logging.get_logger(__name__) + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class DonutFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + do_thumbnail: Optional[bool] + do_align_long_axis: Optional[bool] + do_pad: Optional[bool] + + +@add_start_docstrings( + "Constructs a fast Donut image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + """ + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + """, +) +class DonutImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 2560, "width": 1920} + do_resize = True + do_rescale = True + do_normalize = True + do_thumbnail = True + do_align_long_axis = False + do_pad = True + valid_kwargs = DonutFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[DonutFastImageProcessorKwargs]): + size = kwargs.pop("size", None) + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + super().__init__(**kwargs) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + do_thumbnail (`bool`, *optional*, defaults to `self.do_thumbnail`): + Whether to resize the image using thumbnail method. + do_align_long_axis (`bool`, *optional*, defaults to `self.do_align_long_axis`): + Whether to align the long axis of the image with the long axis of `size` by rotating by 90 degrees. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. If `random_padding` is set to `True`, each image is padded with a random + amount of padding on each size, up to the largest image size in the batch. Otherwise, all images are + padded to the largest image size in the batch. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[DonutFastImageProcessorKwargs]) -> BatchFeature: + if "size" in kwargs: + size = kwargs.pop("size") + if isinstance(size, (tuple, list)): + size = size[::-1] + kwargs["size"] = size + return super().preprocess(images, **kwargs) + + def align_long_axis( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Align the long axis of the image to the longest axis of the specified size. + + Args: + image (`torch.Tensor`): + The image to be aligned. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to align the long axis to. + + Returns: + `torch.Tensor`: The aligned image. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + height_dim, width_dim = image.dim() - 2, image.dim() - 1 + image = torch.rot90(image, 3, dims=[height_dim, width_dim]) + + return image + + def pad_image( + self, + image: "torch.Tensor", + size: SizeDict, + random_padding: bool = False, + ) -> "torch.Tensor": + """ + Pad the image to the specified size. + + Args: + image (`torch.Tensor`): + The image to be padded. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to pad the image to. + random_padding (`bool`, *optional*, defaults to `False`): + Whether to use random padding or not. + data_format (`str` or `ChannelDimension`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + output_height, output_width = size.height, size.width + input_height, input_width = image.shape[-2:] + + delta_width = output_width - input_width + delta_height = output_height - input_height + + if random_padding: + pad_top = torch.random.randint(low=0, high=delta_height + 1) + pad_left = torch.random.randint(low=0, high=delta_width + 1) + else: + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = (pad_left, pad_top, pad_right, pad_bottom) + return F.pad(image, padding) + + def pad(self, *args, **kwargs): + logger.info("pad is deprecated and will be removed in version 4.27. Please use pad_image instead.") + return self.pad_image(*args, **kwargs) + + def thumbnail( + self, + image: "torch.Tensor", + size: SizeDict, + ) -> "torch.Tensor": + """ + Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any + corresponding dimension of the specified size. + + Args: + image (`torch.Tensor`): + The image to be resized. + size (`Dict[str, int]`): + The size `{"height": h, "width": w}` to resize the image to. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + The resampling filter to use. + data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): + The data format of the output image. If unset, the same format as the input image is used. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + input_height, input_width = image.shape[-2:] + output_height, output_width = size.height, size.width + + # We always resize to the smallest of either the input or output size. + height = min(input_height, output_height) + width = min(input_width, output_width) + + if height == input_height and width == input_width: + return image + + if input_height > input_width: + width = int(input_width * height / input_height) + elif input_width > input_height: + height = int(input_height * width / input_width) + + return self.resize( + image, + size=SizeDict(width=width, height=height), + interpolation=F.InterpolationMode.BICUBIC, + ) + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + do_thumbnail: bool, + do_align_long_axis: bool, + do_pad: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_align_long_axis: + stacked_images = self.align_long_axis(image=stacked_images, size=size) + if do_resize: + shortest_edge = min(size.height, size.width) + stacked_images = self.resize( + image=stacked_images, size=SizeDict(shortest_edge=shortest_edge), interpolation=interpolation + ) + if do_thumbnail: + stacked_images = self.thumbnail(image=stacked_images, size=size) + if do_pad: + stacked_images = self.pad_image(image=stacked_images, size=size, random_padding=False) + + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + +__all__ = ["DonutImageProcessorFast"] diff --git a/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py b/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..43d8f3f10798be368f95c285d932d937c7e39178 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/modeling_donut_swin.py @@ -0,0 +1,1145 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Donut Swin Transformer model. + +This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden +states.""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + torch_int, +) +from .configuration_donut_swin import DonutSwinConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "DonutSwinConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base" +_EXPECTED_OUTPUT_SHAPE = [1, 49, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "eljandoubi/donut-base-encoder" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin +class DonutSwinEncoderOutput(ModelOutput): + """ + DonutSwin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin +class DonutSwinModelOutput(ModelOutput): + """ + DonutSwin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->DonutSwin +class DonutSwinImageClassifierOutput(ModelOutput): + """ + DonutSwin outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin +class DonutSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin +class DonutSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class DonutSwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. + """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class DonutSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin +class DonutSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attention_scores = attention_scores + relative_position_bias.unsqueeze(0) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function) + mask_shape = attention_mask.shape[0] + attention_scores = attention_scores.view( + batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim + ) + attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0) + attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput +class DonutSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin +class DonutSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, window_size): + super().__init__() + self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size) + self.output = DonutSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate +class DonutSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput +class DonutSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin +class DonutSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size) + self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DonutSwinIntermediate(config, dim) + self.output = DonutSwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = torch_int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin +class DonutSwinStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + DonutSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[i], + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin +class DonutSwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")] + self.layers = nn.ModuleList( + [ + DonutSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, DonutSwinEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) + else: + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DonutSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin,swin->donut +class DonutSwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DonutSwinConfig + base_model_prefix = "donut" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["DonutSwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, DonutSwinEmbeddings): + if module.mask_token is not None: + module.mask_token.data.zero_() + if module.position_embeddings is not None: + module.position_embeddings.data.zero_() + elif isinstance(module, DonutSwinSelfAttention): + module.relative_position_bias_table.data.zero_() + + +SWIN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`DonutSwinConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SWIN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`DonutImageProcessor.__call__`] for details. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*, defaults to `False`): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Donut Swin Model transformer outputting raw hidden-states without any specific head on top.", + SWIN_START_DOCSTRING, +) +class DonutSwinModel(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=DonutSwinModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DonutSwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + + pooled_output = None + if self.pooler is not None: + pooled_output = self.pooler(sequence_output.transpose(1, 2)) + pooled_output = torch.flatten(pooled_output, 1) + + if not return_dict: + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output + + return DonutSwinModelOutput( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, + ) + + +@add_start_docstrings( + """ + DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + SWIN_START_DOCSTRING, +) +# Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with Swin->DonutSwin,swin->donut +class DonutSwinForImageClassification(DonutSwinPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.donut = DonutSwinModel(config) + + # Classifier head + self.classifier = ( + nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity() + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=DonutSwinImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DonutSwinImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.donut( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=logits, config=self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DonutSwinImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, + ) + + +__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"] diff --git a/docs/transformers/build/lib/transformers/models/donut/processing_donut.py b/docs/transformers/build/lib/transformers/models/donut/processing_donut.py new file mode 100644 index 0000000000000000000000000000000000000000..689aa5122f8feca2862f71f02eaad4519d6d7cbf --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/donut/processing_donut.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Donut. +""" + +import re +import warnings +from contextlib import contextmanager +from typing import List, Optional, Union + +from ...image_utils import ImageInput +from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging + + +class DonutProcessorKwargs(ProcessingKwargs, total=False): + _defaults = {} + + +logger = logging.get_logger(__name__) + + +class DonutProcessor(ProcessorMixin): + r""" + Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single + processor. + + [`DonutProcessor`] offers all the functionalities of [`DonutImageProcessor`] and + [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. See the [`~DonutProcessor.__call__`] and + [`~DonutProcessor.decode`] for more information. + + Args: + image_processor ([`DonutImageProcessor`], *optional*): + An instance of [`DonutImageProcessor`]. The image processor is a required input. + tokenizer ([`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`], *optional*): + An instance of [`XLMRobertaTokenizer`/`XLMRobertaTokenizerFast`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None, **kwargs): + feature_extractor = None + if "feature_extractor" in kwargs: + warnings.warn( + "The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`" + " instead.", + FutureWarning, + ) + feature_extractor = kwargs.pop("feature_extractor") + + image_processor = image_processor if image_processor is not None else feature_extractor + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self._in_target_context_manager = False + + def __call__( + self, + images: ImageInput = None, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[DonutProcessorKwargs], + ): + """ + When used in normal mode, this method forwards all its arguments to AutoImageProcessor's + [`~AutoImageProcessor.__call__`] and returns its output. If used in the context + [`~DonutProcessor.as_target_processor`] this method forwards all its arguments to DonutTokenizer's + [`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information. + """ + if self._in_target_context_manager: + return self.current_processor(images, text, **kwargs) + + if images is None and text is None: + raise ValueError("You need to specify either an `images` or `text` input to process.") + + output_kwargs = self._merge_kwargs( + DonutProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + if images is not None: + inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + if text is not None: + if images is not None: + output_kwargs["text_kwargs"].setdefault("add_special_tokens", False) + encodings = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + if text is None: + return inputs + elif images is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] # for BC + inputs["input_ids"] = encodings["input_ids"] + return inputs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to DonutTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the + docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @contextmanager + def as_target_processor(self): + """ + Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning TrOCR. + """ + warnings.warn( + "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your " + "labels by using the argument `text` of the regular `__call__` method (either in the same call as " + "your images inputs, or in a separate call." + ) + self._in_target_context_manager = True + self.current_processor = self.tokenizer + yield + self.current_processor = self.image_processor + self._in_target_context_manager = False + + def token2json(self, tokens, is_inner_value=False, added_vocab=None): + """ + Convert a (generated) token sequence into an ordered JSON format. + """ + if added_vocab is None: + added_vocab = self.tokenizer.get_added_vocab() + + output = {} + + while tokens: + start_token = re.search(r"", tokens, re.IGNORECASE) + if start_token is None: + break + key = start_token.group(1) + key_escaped = re.escape(key) + + end_token = re.search(rf"", tokens, re.IGNORECASE) + start_token = start_token.group() + if end_token is None: + tokens = tokens.replace(start_token, "") + else: + end_token = end_token.group() + start_token_escaped = re.escape(start_token) + end_token_escaped = re.escape(end_token) + content = re.search( + f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE | re.DOTALL + ) + if content is not None: + content = content.group(1).strip() + if r""): + leaf = leaf.strip() + if leaf in added_vocab and leaf[0] == "<" and leaf[-2:] == "/>": + leaf = leaf[1:-2] # for categorical special tokens + output[key].append(leaf) + if len(output[key]) == 1: + output[key] = output[key][0] + + tokens = tokens[tokens.find(end_token) + len(end_token) :].strip() + if tokens[:6] == r"": # non-leaf nodes + return [output] + self.token2json(tokens[6:], is_inner_value=True, added_vocab=added_vocab) + + if len(output): + return [output] if is_inner_value else output + else: + return [] if is_inner_value else {"text_sequence": tokens} + + @property + def feature_extractor_class(self): + warnings.warn( + "`feature_extractor_class` is deprecated and will be removed in v5. Use `image_processor_class` instead.", + FutureWarning, + ) + return self.image_processor_class + + @property + def feature_extractor(self): + warnings.warn( + "`feature_extractor` is deprecated and will be removed in v5. Use `image_processor` instead.", + FutureWarning, + ) + return self.image_processor + + +__all__ = ["DonutProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/dpr/__init__.py b/docs/transformers/build/lib/transformers/models/dpr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9aeadbeaf416575570c280a3e15a52422a007103 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dpr import * + from .modeling_dpr import * + from .modeling_tf_dpr import * + from .tokenization_dpr import * + from .tokenization_dpr_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py new file mode 100644 index 0000000000000000000000000000000000000000..7e4b97c97a4f7f1acedb0f2e23be4fb5dd770a99 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/configuration_dpr.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2010, DPR authors, The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DPR model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class DPRConfig(PretrainedConfig): + r""" + [`DPRConfig`] is the configuration class to store the configuration of a *DPRModel*. + + This is the configuration class to store the configuration of a [`DPRContextEncoder`], [`DPRQuestionEncoder`], or a + [`DPRReader`]. It is used to instantiate the components of the DPR model according to the specified arguments, + defining the model component architectures. Instantiating a configuration with the defaults will yield a similar + configuration to that of the DPRContextEncoder + [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base) + architecture. + + This class is a subclass of [`BertConfig`]. Please check the superclass for the documentation of all kwargs. + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the DPR model. Defines the different tokens that can be represented by the *inputs_ids* + passed to the forward method of [`BertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the *token_type_ids* passed into [`BertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + projection_dim (`int`, *optional*, defaults to 0): + Dimension of the projection for the context and question encoders. If it is set to zero (default), then no + projection is done. + + Example: + + ```python + >>> from transformers import DPRConfig, DPRContextEncoder + + >>> # Initializing a DPR facebook/dpr-ctx_encoder-single-nq-base style configuration + >>> configuration = DPRConfig() + + >>> # Initializing a model (with random weights) from the facebook/dpr-ctx_encoder-single-nq-base style configuration + >>> model = DPRContextEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dpr" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + projection_dim: int = 0, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.projection_dim = projection_dim + self.position_embedding_type = position_embedding_type + + +__all__ = ["DPRConfig"] diff --git a/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..5151c0972a7ed72c47d125400b918aba3a0d3c0d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py @@ -0,0 +1,145 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import collections +from pathlib import Path + +import torch +from torch.serialization import default_restore_location + +from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader + + +CheckpointState = collections.namedtuple( + "CheckpointState", ["model_dict", "optimizer_dict", "scheduler_dict", "offset", "epoch", "encoder_params"] +) + + +def load_states_from_checkpoint(model_file: str) -> CheckpointState: + print(f"Reading saved model from {model_file}") + state_dict = torch.load( + model_file, map_location=lambda s, l: default_restore_location(s, "cpu"), weights_only=True + ) + return CheckpointState(**state_dict) + + +class DPRState: + def __init__(self, src_file: Path): + self.src_file = src_file + + def load_dpr_model(self): + raise NotImplementedError + + @staticmethod + def from_type(comp_type: str, *args, **kwargs) -> "DPRState": + if comp_type.startswith("c"): + return DPRContextEncoderState(*args, **kwargs) + if comp_type.startswith("q"): + return DPRQuestionEncoderState(*args, **kwargs) + if comp_type.startswith("r"): + return DPRReaderState(*args, **kwargs) + else: + raise ValueError("Component type must be either 'ctx_encoder', 'question_encoder' or 'reader'.") + + +class DPRContextEncoderState(DPRState): + def load_dpr_model(self): + model = DPRContextEncoder(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0])) + print(f"Loading DPR biencoder from {self.src_file}") + saved_state = load_states_from_checkpoint(self.src_file) + encoder, prefix = model.ctx_encoder, "ctx_model." + # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 + state_dict = {"bert_model.embeddings.position_ids": model.ctx_encoder.bert_model.embeddings.position_ids} + for key, value in saved_state.model_dict.items(): + if key.startswith(prefix): + key = key[len(prefix) :] + if not key.startswith("encode_proj."): + key = "bert_model." + key + state_dict[key] = value + encoder.load_state_dict(state_dict) + return model + + +class DPRQuestionEncoderState(DPRState): + def load_dpr_model(self): + model = DPRQuestionEncoder(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0])) + print(f"Loading DPR biencoder from {self.src_file}") + saved_state = load_states_from_checkpoint(self.src_file) + encoder, prefix = model.question_encoder, "question_model." + # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 + state_dict = {"bert_model.embeddings.position_ids": model.question_encoder.bert_model.embeddings.position_ids} + for key, value in saved_state.model_dict.items(): + if key.startswith(prefix): + key = key[len(prefix) :] + if not key.startswith("encode_proj."): + key = "bert_model." + key + state_dict[key] = value + encoder.load_state_dict(state_dict) + return model + + +class DPRReaderState(DPRState): + def load_dpr_model(self): + model = DPRReader(DPRConfig(**BertConfig.get_config_dict("google-bert/bert-base-uncased")[0])) + print(f"Loading DPR reader from {self.src_file}") + saved_state = load_states_from_checkpoint(self.src_file) + # Fix changes from https://github.com/huggingface/transformers/commit/614fef1691edb806de976756d4948ecbcd0c0ca3 + state_dict = { + "encoder.bert_model.embeddings.position_ids": model.span_predictor.encoder.bert_model.embeddings.position_ids + } + for key, value in saved_state.model_dict.items(): + if key.startswith("encoder.") and not key.startswith("encoder.encode_proj"): + key = "encoder.bert_model." + key[len("encoder.") :] + state_dict[key] = value + model.span_predictor.load_state_dict(state_dict) + return model + + +def convert(comp_type: str, src_file: Path, dest_dir: Path): + dest_dir = Path(dest_dir) + dest_dir.mkdir(exist_ok=True) + + dpr_state = DPRState.from_type(comp_type, src_file=src_file) + model = dpr_state.load_dpr_model() + model.save_pretrained(dest_dir) + model.from_pretrained(dest_dir) # sanity check + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--type", type=str, help="Type of the component to convert: 'ctx_encoder', 'question_encoder' or 'reader'." + ) + parser.add_argument( + "--src", + type=str, + help=( + "Path to the dpr checkpoint file. They can be downloaded from the official DPR repo" + " https://github.com/facebookresearch/DPR. Note that in the official repo, both encoders are stored in the" + " 'retriever' checkpoints." + ), + ) + parser.add_argument("--dest", type=str, default=None, help="Path to the output PyTorch model directory.") + args = parser.parse_args() + + src_file = Path(args.src) + dest_dir = f"converted-{src_file.name}" if args.dest is None else args.dest + dest_dir = Path(dest_dir) + assert src_file.exists() + assert args.type is not None, ( + "Please specify the component type of the DPR model to convert: 'ctx_encoder', 'question_encoder' or 'reader'." + ) + convert(args.type, src_file, dest_dir) diff --git a/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff4aa11523e730dee6dc014aef58479f0c2775a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/modeling_dpr.py @@ -0,0 +1,668 @@ +# coding=utf-8 +# Copyright 2018 DPR Authors, The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DPR model for Open Domain Question Answering.""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn + +from ...modeling_outputs import BaseModelOutputWithPooling +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..bert.modeling_bert import BertModel +from .configuration_dpr import DPRConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DPRConfig" +_CHECKPOINT_FOR_DOC = "facebook/dpr-ctx_encoder-single-nq-base" + + +########## +# Outputs +########## + + +@dataclass +class DPRContextEncoderOutput(ModelOutput): + """ + Class for outputs of [`DPRQuestionEncoder`]. + + Args: + pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): + The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer + hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. + This output is to be used to embed contexts for nearest neighbors queries with questions embeddings. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + pooler_output: torch.FloatTensor + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DPRQuestionEncoderOutput(ModelOutput): + """ + Class for outputs of [`DPRQuestionEncoder`]. + + Args: + pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`): + The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer + hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. + This output is to be used to embed questions for nearest neighbors queries with context embeddings. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + pooler_output: torch.FloatTensor + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DPRReaderOutput(ModelOutput): + """ + Class for outputs of [`DPRQuestionEncoder`]. + + Args: + start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`): + Logits of the start index of the span for each passage. + end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`): + Logits of the end index of the span for each passage. + relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`): + Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the + question, compared to all the other passages. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: torch.FloatTensor + end_logits: Optional[torch.FloatTensor] = None + relevance_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class DPRPreTrainedModel(PreTrainedModel): + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DPREncoder(DPRPreTrainedModel): + base_model_prefix = "bert_model" + + def __init__(self, config: DPRConfig): + super().__init__(config) + self.bert_model = BertModel(config, add_pooling_layer=False) + if self.bert_model.config.hidden_size <= 0: + raise ValueError("Encoder hidden_size can't be zero") + self.projection_dim = config.projection_dim + if self.projection_dim > 0: + self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Tensor, + attention_mask: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]: + outputs = self.bert_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + pooled_output = sequence_output[:, 0, :] + + if self.projection_dim > 0: + pooled_output = self.encode_proj(pooled_output) + + if not return_dict: + return (sequence_output, pooled_output) + outputs[2:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @property + def embeddings_size(self) -> int: + if self.projection_dim > 0: + return self.encode_proj.out_features + return self.bert_model.config.hidden_size + + +class DPRSpanPredictor(DPRPreTrainedModel): + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig): + super().__init__(config) + self.encoder = DPREncoder(config) + self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2) + self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: Tensor, + attention_mask: Tensor, + inputs_embeds: Optional[Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]: + # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length + n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2] + # feed encoder + outputs = self.encoder( + input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + + # compute logits + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + relevance_logits = self.qa_classifier(sequence_output[:, 0, :]) + + # resize + start_logits = start_logits.view(n_passages, sequence_length) + end_logits = end_logits.view(n_passages, sequence_length) + relevance_logits = relevance_logits.view(n_passages) + + if not return_dict: + return (start_logits, end_logits, relevance_logits) + outputs[2:] + + return DPRReaderOutput( + start_logits=start_logits, + end_logits=end_logits, + relevance_logits=relevance_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +################## +# PreTrainedModel +################## + + +class DPRPretrainedContextEncoder(DPRPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + load_tf_weights = None + base_model_prefix = "ctx_encoder" + + +class DPRPretrainedQuestionEncoder(DPRPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + load_tf_weights = None + base_model_prefix = "question_encoder" + + +class DPRPretrainedReader(DPRPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + load_tf_weights = None + base_model_prefix = "span_predictor" + + +############### +# Actual Models +############### + + +DPR_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DPRConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DPR_ENCODERS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be + formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs (for a pair title+text for example): + + ``` + tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + ``` + + (b) For single sequences (for a question for example): + + ``` + tokens: [CLS] the dog is hairy . [SEP] + token_type_ids: 0 0 0 0 0 0 0 + ``` + + DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right + rather than the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +DPR_READER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`): + Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question + and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should + be formatted with [CLS] and [SEP] with the format: + + `[CLS] [SEP] [SEP] ` + + DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right + rather than the left. + + Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `(n_passages, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.", + DPR_START_DOCSTRING, +) +class DPRContextEncoder(DPRPretrainedContextEncoder): + def __init__(self, config: DPRConfig): + super().__init__(config) + self.config = config + self.ctx_encoder = DPREncoder(config) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DPR_ENCODERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[DPRContextEncoderOutput, Tuple[Tensor, ...]]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer + + >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base") + >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base") + >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"] + >>> embeddings = model(input_ids).pooler_output + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = ( + torch.ones(input_shape, device=device) + if input_ids is None + else (input_ids != self.config.pad_token_id) + ) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + outputs = self.ctx_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs[1:] + return DPRContextEncoderOutput( + pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.", + DPR_START_DOCSTRING, +) +class DPRQuestionEncoder(DPRPretrainedQuestionEncoder): + def __init__(self, config: DPRConfig): + super().__init__(config) + self.config = config + self.question_encoder = DPREncoder(config) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DPR_ENCODERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + token_type_ids: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[DPRQuestionEncoderOutput, Tuple[Tensor, ...]]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer + + >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"] + >>> embeddings = model(input_ids).pooler_output + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = ( + torch.ones(input_shape, device=device) + if input_ids is None + else (input_ids != self.config.pad_token_id) + ) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + outputs = self.question_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return outputs[1:] + return DPRQuestionEncoderOutput( + pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + "The bare DPRReader transformer outputting span predictions.", + DPR_START_DOCSTRING, +) +class DPRReader(DPRPretrainedReader): + def __init__(self, config: DPRConfig): + super().__init__(config) + self.config = config + self.span_predictor = DPRSpanPredictor(config) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DPR_READER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DPRReaderOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import DPRReader, DPRReaderTokenizer + + >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> encoded_inputs = tokenizer( + ... questions=["What is love ?"], + ... titles=["Haddaway"], + ... texts=["'What Is Love' is a song recorded by the artist Haddaway"], + ... return_tensors="pt", + ... ) + >>> outputs = model(**encoded_inputs) + >>> start_logits = outputs.start_logits + >>> end_logits = outputs.end_logits + >>> relevance_logits = outputs.relevance_logits + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + + return self.span_predictor( + input_ids, + attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +__all__ = [ + "DPRContextEncoder", + "DPRPretrainedContextEncoder", + "DPRPreTrainedModel", + "DPRPretrainedQuestionEncoder", + "DPRPretrainedReader", + "DPRQuestionEncoder", + "DPRReader", +] diff --git a/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py new file mode 100644 index 0000000000000000000000000000000000000000..303b03ec244d61c8698d8115cf8237cdbe1f960b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/modeling_tf_dpr.py @@ -0,0 +1,800 @@ +# coding=utf-8 +# Copyright 2018 DPR Authors, The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TensorFlow DPR model for Open Domain Question Answering.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import tensorflow as tf + +from ...modeling_tf_outputs import TFBaseModelOutputWithPooling +from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, get_initializer, keras, shape_list, unpack_inputs +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..bert.modeling_tf_bert import TFBertMainLayer +from .configuration_dpr import DPRConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DPRConfig" + + +########## +# Outputs +########## + + +@dataclass +class TFDPRContextEncoderOutput(ModelOutput): + r""" + Class for outputs of [`TFDPRContextEncoder`]. + + Args: + pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`): + The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer + hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. + This output is to be used to embed contexts for nearest neighbors queries with questions embeddings. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + pooler_output: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFDPRQuestionEncoderOutput(ModelOutput): + """ + Class for outputs of [`TFDPRQuestionEncoder`]. + + Args: + pooler_output (`tf.Tensor` of shape `(batch_size, embeddings_size)`): + The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer + hidden-state of the first token of the sequence (classification token) further processed by a Linear layer. + This output is to be used to embed questions for nearest neighbors queries with context embeddings. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + pooler_output: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + + +@dataclass +class TFDPRReaderOutput(ModelOutput): + """ + Class for outputs of [`TFDPRReaderEncoder`]. + + Args: + start_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`): + Logits of the start index of the span for each passage. + end_logits (`tf.Tensor` of shape `(n_passages, sequence_length)`): + Logits of the end index of the span for each passage. + relevance_logits (`tf.Tensor` of shape `(n_passages, )`): + Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the + question, compared to all the other passages. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + start_logits: Optional[tf.Tensor] = None + end_logits: Optional[tf.Tensor] = None + relevance_logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor, ...] | None = None + attentions: Tuple[tf.Tensor, ...] | None = None + + +class TFDPREncoderLayer(keras.layers.Layer): + base_model_prefix = "bert_model" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(**kwargs) + + # resolve name conflict with TFBertMainLayer instead of TFBertModel + self.bert_model = TFBertMainLayer(config, add_pooling_layer=False, name="bert_model") + self.config = config + + if self.config.hidden_size <= 0: + raise ValueError("Encoder hidden_size can't be zero") + self.projection_dim = config.projection_dim + if self.projection_dim > 0: + self.encode_proj = keras.layers.Dense( + config.projection_dim, kernel_initializer=get_initializer(config.initializer_range), name="encode_proj" + ) + + @unpack_inputs + def call( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor, ...]]: + outputs = self.bert_model( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + pooled_output = sequence_output[:, 0, :] + if self.projection_dim > 0: + pooled_output = self.encode_proj(pooled_output) + + if not return_dict: + return (sequence_output, pooled_output) + outputs[1:] + + return TFBaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @property + def embeddings_size(self) -> int: + if self.projection_dim > 0: + return self.projection_dim + return self.bert_model.config.hidden_size + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "bert_model", None) is not None: + with tf.name_scope(self.bert_model.name): + self.bert_model.build(None) + if getattr(self, "encode_proj", None) is not None: + with tf.name_scope(self.encode_proj.name): + self.encode_proj.build(None) + + +class TFDPRSpanPredictorLayer(keras.layers.Layer): + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.encoder = TFDPREncoderLayer(config, name="encoder") + + self.qa_outputs = keras.layers.Dense( + 2, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.qa_classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="qa_classifier" + ) + + @unpack_inputs + def call( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + training: bool = False, + ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]: + # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length + n_passages, sequence_length = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)[:2] + # feed encoder + outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + # compute logits + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + relevance_logits = self.qa_classifier(sequence_output[:, 0, :]) + + # resize + start_logits = tf.reshape(start_logits, [n_passages, sequence_length]) + end_logits = tf.reshape(end_logits, [n_passages, sequence_length]) + relevance_logits = tf.reshape(relevance_logits, [n_passages]) + + if not return_dict: + return (start_logits, end_logits, relevance_logits) + outputs[2:] + + return TFDPRReaderOutput( + start_logits=start_logits, + end_logits=end_logits, + relevance_logits=relevance_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.encoder.embeddings_size]) + if getattr(self, "qa_classifier", None) is not None: + with tf.name_scope(self.qa_classifier.name): + self.qa_classifier.build([None, None, self.encoder.embeddings_size]) + + +class TFDPRSpanPredictor(TFPreTrainedModel): + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(config, **kwargs) + self.encoder = TFDPRSpanPredictorLayer(config) + + @unpack_inputs + def call( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + training: bool = False, + ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]: + outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + +class TFDPREncoder(TFPreTrainedModel): + base_model_prefix = "encoder" + + def __init__(self, config: DPRConfig, **kwargs): + super().__init__(config, **kwargs) + + self.encoder = TFDPREncoderLayer(config) + + @unpack_inputs + def call( + self, + input_ids: Optional[tf.Tensor] = None, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = False, + training: bool = False, + ) -> Union[TFDPRReaderOutput, Tuple[tf.Tensor, ...]]: + outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + +################## +# PreTrainedModel +################## + + +class TFDPRPretrainedContextEncoder(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + base_model_prefix = "ctx_encoder" + + +class TFDPRPretrainedQuestionEncoder(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + base_model_prefix = "question_encoder" + + +class TFDPRPretrainedReader(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPRConfig + base_model_prefix = "reader" + + +############### +# Actual Models +############### + + +TF_DPR_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Tensorflow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) + subclass. Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to + general usage and behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`DPRConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +TF_DPR_ENCODERS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be + formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs (for a pair title+text for example): + + ``` + tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + ``` + + (b) For single sequences (for a question for example): + + ``` + tokens: [CLS] the dog is hairy . [SEP] + token_type_ids: 0 0 0 0 0 0 0 + ``` + + DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right + rather than the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + +TF_DPR_READER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shapes `(n_passages, sequence_length)`): + Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question + and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should + be formatted with [CLS] and [SEP] with the format: + + `[CLS] [SEP] [SEP] ` + + DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right + rather than the left. + + Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details. + attention_mask (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`Numpy array` or `tf.Tensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare DPRContextEncoder transformer outputting pooler outputs as context representations.", + TF_DPR_START_DOCSTRING, +) +class TFDPRContextEncoder(TFDPRPretrainedContextEncoder): + def __init__(self, config: DPRConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder") + + def get_input_embeddings(self): + try: + return self.ctx_encoder.bert_model.get_input_embeddings() + except AttributeError: + self.build() + return self.ctx_encoder.bert_model.get_input_embeddings() + + @unpack_inputs + @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFDPRContextEncoderOutput | Tuple[tf.Tensor, ...]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import TFDPRContextEncoder, DPRContextEncoderTokenizer + + >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base") + >>> model = TFDPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", from_pt=True) + >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"] + >>> embeddings = model(input_ids).pooler_output + ``` + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = ( + tf.ones(input_shape, dtype=tf.dtypes.int32) + if input_ids is None + else (input_ids != self.config.pad_token_id) + ) + if token_type_ids is None: + token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32) + + outputs = self.ctx_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs[1:] + + return TFDPRContextEncoderOutput( + pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "ctx_encoder", None) is not None: + with tf.name_scope(self.ctx_encoder.name): + self.ctx_encoder.build(None) + + +@add_start_docstrings( + "The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.", + TF_DPR_START_DOCSTRING, +) +class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder): + def __init__(self, config: DPRConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.question_encoder = TFDPREncoderLayer(config, name="question_encoder") + + def get_input_embeddings(self): + try: + return self.question_encoder.bert_model.get_input_embeddings() + except AttributeError: + self.build() + return self.question_encoder.bert_model.get_input_embeddings() + + @unpack_inputs + @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFDPRQuestionEncoderOutput | Tuple[tf.Tensor, ...]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import TFDPRQuestionEncoder, DPRQuestionEncoderTokenizer + + >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base") + >>> model = TFDPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base", from_pt=True) + >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="tf")["input_ids"] + >>> embeddings = model(input_ids).pooler_output + ``` + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = ( + tf.ones(input_shape, dtype=tf.dtypes.int32) + if input_ids is None + else (input_ids != self.config.pad_token_id) + ) + if token_type_ids is None: + token_type_ids = tf.zeros(input_shape, dtype=tf.dtypes.int32) + + outputs = self.question_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + if not return_dict: + return outputs[1:] + return TFDPRQuestionEncoderOutput( + pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "question_encoder", None) is not None: + with tf.name_scope(self.question_encoder.name): + self.question_encoder.build(None) + + +@add_start_docstrings( + "The bare DPRReader transformer outputting span predictions.", + TF_DPR_START_DOCSTRING, +) +class TFDPRReader(TFDPRPretrainedReader): + def __init__(self, config: DPRConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor") + + def get_input_embeddings(self): + try: + return self.span_predictor.encoder.bert_model.get_input_embeddings() + except AttributeError: + self.build() + return self.span_predictor.encoder.bert_model.get_input_embeddings() + + @unpack_inputs + @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + training: bool = False, + ) -> TFDPRReaderOutput | Tuple[tf.Tensor, ...]: + r""" + Return: + + Examples: + + ```python + >>> from transformers import TFDPRReader, DPRReaderTokenizer + + >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> model = TFDPRReader.from_pretrained("facebook/dpr-reader-single-nq-base", from_pt=True) + >>> encoded_inputs = tokenizer( + ... questions=["What is love ?"], + ... titles=["Haddaway"], + ... texts=["'What Is Love' is a song recorded by the artist Haddaway"], + ... return_tensors="tf", + ... ) + >>> outputs = model(encoded_inputs) + >>> start_logits = outputs.start_logits + >>> end_logits = outputs.end_logits + >>> relevance_logits = outputs.relevance_logits + ``` + """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = tf.ones(input_shape, dtype=tf.dtypes.int32) + + return self.span_predictor( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "span_predictor", None) is not None: + with tf.name_scope(self.span_predictor.name): + self.span_predictor.build(None) + + +__all__ = [ + "TFDPRContextEncoder", + "TFDPRPretrainedContextEncoder", + "TFDPRPretrainedQuestionEncoder", + "TFDPRPretrainedReader", + "TFDPRQuestionEncoder", + "TFDPRReader", +] diff --git a/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py new file mode 100644 index 0000000000000000000000000000000000000000..00b8dedfa7e4b7d190ad4fa3b9f68597996b038f --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DPR.""" + +import collections +from typing import List, Optional, Union + +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging +from ..bert.tokenization_bert import BertTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class DPRContextEncoderTokenizer(BertTokenizer): + r""" + Construct a DPRContextEncoder tokenizer. + + [`DPRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. + + Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + +class DPRQuestionEncoderTokenizer(BertTokenizer): + r""" + Constructs a DPRQuestionEncoder tokenizer. + + [`DPRQuestionEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. + + Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + +DPRSpanPrediction = collections.namedtuple( + "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"] +) + +DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"]) + + +CUSTOM_DPR_READER_DOCSTRING = r""" + Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`. + It converts the strings of a question and different passages (title and text) in a sequence of IDs (integers), + using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)` + with the format: + + ``` + [CLS] [SEP] [SEP] + ``` + + Args: + questions (`str` or `List[str]`): + The questions to be encoded. You can specify one question for many passages. In this case, the question + will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in + `titles` or `texts`. + titles (`str` or `List[str]`): + The passages titles to be encoded. This can be a string or a list of strings if there are several passages. + texts (`str` or `List[str]`): + The passages texts to be encoded. This can be a string or a list of strings if there are several passages. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence + if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch + of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. This will only truncate the first + sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. This will only truncate the + second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_attention_mask (`bool`, *optional*): + Whether or not to return the attention mask. If not set, will return the attention mask according to the + specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + + Returns: + `Dict[str, List[List[int]]]`: A dictionary with the following keys: + + - `input_ids`: List of token ids to be fed to a model. + - `attention_mask`: List of indices specifying which tokens should be attended to by the model. + """ + + +@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING) +class CustomDPRReaderTokenizerMixin: + def __call__( + self, + questions, + titles: Optional[str] = None, + texts: Optional[str] = None, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = None, + **kwargs, + ) -> BatchEncoding: + if titles is None and texts is None: + return super().__call__( + questions, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + **kwargs, + ) + elif titles is None or texts is None: + text_pair = titles if texts is None else texts + return super().__call__( + questions, + text_pair, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + **kwargs, + ) + titles = titles if not isinstance(titles, str) else [titles] + texts = texts if not isinstance(texts, str) else [texts] + n_passages = len(titles) + questions = questions if not isinstance(questions, str) else [questions] * n_passages + if len(titles) != len(texts): + raise ValueError( + f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts." + ) + encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"] + encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"] + encoded_inputs = { + "input_ids": [ + (encoded_question_and_title + encoded_text)[:max_length] + if max_length is not None and truncation + else encoded_question_and_title + encoded_text + for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts) + ] + } + if return_attention_mask is not False: + attention_mask = [] + for input_ids in encoded_inputs["input_ids"]: + attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids]) + encoded_inputs["attention_mask"] = attention_mask + return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors) + + def decode_best_spans( + self, + reader_input: BatchEncoding, + reader_output: DPRReaderOutput, + num_spans: int = 16, + max_answer_length: int = 64, + num_spans_per_passage: int = 4, + ) -> List[DPRSpanPrediction]: + """ + Get the span predictions for the extractive Q&A model. + + Returns: *List* of *DPRReaderOutput* sorted by descending *(relevance_score, span_score)*. Each + *DPRReaderOutput* is a *Tuple* with: + + - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other + spans in the same passage. It corresponds to the sum of the start and end logits of the span. + - **relevance_score**: `float` that corresponds to the score of the each passage to answer the question, + compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader. + - **doc_id**: `int` the id of the passage. - **start_index**: `int` the start index of the span + (inclusive). - **end_index**: `int` the end index of the span (inclusive). + + Examples: + + ```python + >>> from transformers import DPRReader, DPRReaderTokenizer + + >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> encoded_inputs = tokenizer( + ... questions=["What is love ?"], + ... titles=["Haddaway"], + ... texts=["'What Is Love' is a song recorded by the artist Haddaway"], + ... return_tensors="pt", + ... ) + >>> outputs = model(**encoded_inputs) + >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs) + >>> print(predicted_spans[0].text) # best span + a song + ```""" + input_ids = reader_input["input_ids"] + start_logits, end_logits, relevance_logits = reader_output[:3] + n_passages = len(relevance_logits) + sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__) + nbest_spans_predictions: List[DPRReaderOutput] = [] + for doc_id in sorted_docs: + sequence_ids = list(input_ids[doc_id]) + # assuming question & title information is at the beginning of the sequence + passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id + if sequence_ids[-1] == self.pad_token_id: + sequence_len = sequence_ids.index(self.pad_token_id) + else: + sequence_len = len(sequence_ids) + + best_spans = self._get_best_spans( + start_logits=start_logits[doc_id][passage_offset:sequence_len], + end_logits=end_logits[doc_id][passage_offset:sequence_len], + max_answer_length=max_answer_length, + top_spans=num_spans_per_passage, + ) + for start_index, end_index in best_spans: + start_index += passage_offset + end_index += passage_offset + nbest_spans_predictions.append( + DPRSpanPrediction( + span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index], + relevance_score=relevance_logits[doc_id], + doc_id=doc_id, + start_index=start_index, + end_index=end_index, + text=self.decode(sequence_ids[start_index : end_index + 1]), + ) + ) + if len(nbest_spans_predictions) >= num_spans: + break + return nbest_spans_predictions[:num_spans] + + def _get_best_spans( + self, + start_logits: List[int], + end_logits: List[int], + max_answer_length: int, + top_spans: int, + ) -> List[DPRSpanPrediction]: + """ + Finds the best answer span for the extractive Q&A model for one passage. It returns the best span by descending + `span_score` order and keeping max `top_spans` spans. Spans longer that `max_answer_length` are ignored. + """ + scores = [] + for start_index, start_score in enumerate(start_logits): + for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]): + scores.append(((start_index, start_index + answer_length), start_score + end_score)) + scores = sorted(scores, key=lambda x: x[1], reverse=True) + chosen_span_intervals = [] + for (start_index, end_index), score in scores: + if start_index > end_index: + raise ValueError(f"Wrong span indices: [{start_index}:{end_index}]") + length = end_index - start_index + 1 + if length > max_answer_length: + raise ValueError(f"Span is too long: {length} > {max_answer_length}") + if any( + start_index <= prev_start_index <= prev_end_index <= end_index + or prev_start_index <= start_index <= end_index <= prev_end_index + for (prev_start_index, prev_end_index) in chosen_span_intervals + ): + continue + chosen_span_intervals.append((start_index, end_index)) + + if len(chosen_span_intervals) == top_spans: + break + return chosen_span_intervals + + +@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING) +class DPRReaderTokenizer(CustomDPRReaderTokenizerMixin, BertTokenizer): + r""" + Construct a DPRReader tokenizer. + + [`DPRReaderTokenizer`] is almost identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation + splitting and wordpiece. The difference is that is has three inputs strings: question, titles and texts that are + combined to be fed to the [`DPRReader`] model. + + Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + +__all__ = ["DPRContextEncoderTokenizer", "DPRQuestionEncoderTokenizer", "DPRReaderOutput", "DPRReaderTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e7c0fdcdbf72aebb97216927dea52577b8bc4c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpr/tokenization_dpr_fast.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team, The Hugging Face Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for DPR.""" + +import collections +from typing import List, Optional, Union + +from ...tokenization_utils_base import BatchEncoding +from ...utils import TensorType, add_end_docstrings, add_start_docstrings, logging +from ..bert.tokenization_bert_fast import BertTokenizerFast +from .tokenization_dpr import DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer, DPRReaderTokenizer + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +class DPRContextEncoderTokenizerFast(BertTokenizerFast): + r""" + Construct a "fast" DPRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library). + + [`DPRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: + punctuation splitting and wordpiece. + + Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = DPRContextEncoderTokenizer + + +class DPRQuestionEncoderTokenizerFast(BertTokenizerFast): + r""" + Constructs a "fast" DPRQuestionEncoder tokenizer (backed by HuggingFace's *tokenizers* library). + + [`DPRQuestionEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization: + punctuation splitting and wordpiece. + + Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = DPRQuestionEncoderTokenizer + + +DPRSpanPrediction = collections.namedtuple( + "DPRSpanPrediction", ["span_score", "relevance_score", "doc_id", "start_index", "end_index", "text"] +) + +DPRReaderOutput = collections.namedtuple("DPRReaderOutput", ["start_logits", "end_logits", "relevance_logits"]) + + +CUSTOM_DPR_READER_DOCSTRING = r""" + Return a dictionary with the token ids of the input strings and other information to give to `.decode_best_spans`. + It converts the strings of a question and different passages (title and text) in a sequence of IDs (integers), + using the tokenizer and vocabulary. The resulting `input_ids` is a matrix of size `(n_passages, sequence_length)` + with the format: + + [CLS] [SEP] [SEP] + + Args: + questions (`str` or `List[str]`): + The questions to be encoded. You can specify one question for many passages. In this case, the question + will be duplicated like `[questions] * n_passages`. Otherwise you have to specify as many questions as in + `titles` or `texts`. + titles (`str` or `List[str]`): + The passages titles to be encoded. This can be a string or a list of strings if there are several passages. + texts (`str` or `List[str]`): + The passages texts to be encoded. This can be a string or a list of strings if there are several passages. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence + if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch + of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. This will only truncate the first + sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. This will only truncate the + second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_attention_mask (`bool`, *optional*): + Whether or not to return the attention mask. If not set, will return the attention mask according to the + specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + + Return: + `Dict[str, List[List[int]]]`: A dictionary with the following keys: + + - `input_ids`: List of token ids to be fed to a model. + - `attention_mask`: List of indices specifying which tokens should be attended to by the model. + """ + + +@add_start_docstrings(CUSTOM_DPR_READER_DOCSTRING) +class CustomDPRReaderTokenizerMixin: + def __call__( + self, + questions, + titles: Optional[str] = None, + texts: Optional[str] = None, + padding: Union[bool, str] = False, + truncation: Union[bool, str] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: Optional[bool] = None, + **kwargs, + ) -> BatchEncoding: + if titles is None and texts is None: + return super().__call__( + questions, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + **kwargs, + ) + elif titles is None or texts is None: + text_pair = titles if texts is None else texts + return super().__call__( + questions, + text_pair, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + **kwargs, + ) + titles = titles if not isinstance(titles, str) else [titles] + texts = texts if not isinstance(texts, str) else [texts] + n_passages = len(titles) + questions = questions if not isinstance(questions, str) else [questions] * n_passages + assert len(titles) == len(texts), ( + f"There should be as many titles than texts but got {len(titles)} titles and {len(texts)} texts." + ) + encoded_question_and_titles = super().__call__(questions, titles, padding=False, truncation=False)["input_ids"] + encoded_texts = super().__call__(texts, add_special_tokens=False, padding=False, truncation=False)["input_ids"] + encoded_inputs = { + "input_ids": [ + (encoded_question_and_title + encoded_text)[:max_length] + if max_length is not None and truncation + else encoded_question_and_title + encoded_text + for encoded_question_and_title, encoded_text in zip(encoded_question_and_titles, encoded_texts) + ] + } + if return_attention_mask is not False: + attention_mask = [] + for input_ids in encoded_inputs["input_ids"]: + attention_mask.append([int(input_id != self.pad_token_id) for input_id in input_ids]) + encoded_inputs["attention_mask"] = attention_mask + return self.pad(encoded_inputs, padding=padding, max_length=max_length, return_tensors=return_tensors) + + def decode_best_spans( + self, + reader_input: BatchEncoding, + reader_output: DPRReaderOutput, + num_spans: int = 16, + max_answer_length: int = 64, + num_spans_per_passage: int = 4, + ) -> List[DPRSpanPrediction]: + """ + Get the span predictions for the extractive Q&A model. + + Returns: *List* of *DPRReaderOutput* sorted by descending *(relevance_score, span_score)*. Each + *DPRReaderOutput* is a *Tuple* with: + + - **span_score**: `float` that corresponds to the score given by the reader for this span compared to other + spans in the same passage. It corresponds to the sum of the start and end logits of the span. + - **relevance_score**: `float` that corresponds to the score of the each passage to answer the question, + compared to all the other passages. It corresponds to the output of the QA classifier of the DPRReader. + - **doc_id**: `int` the id of the passage. - ***start_index**: `int` the start index of the span + (inclusive). - **end_index**: `int` the end index of the span (inclusive). + + Examples: + + ```python + >>> from transformers import DPRReader, DPRReaderTokenizer + + >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base") + >>> encoded_inputs = tokenizer( + ... questions=["What is love ?"], + ... titles=["Haddaway"], + ... texts=["'What Is Love' is a song recorded by the artist Haddaway"], + ... return_tensors="pt", + ... ) + >>> outputs = model(**encoded_inputs) + >>> predicted_spans = tokenizer.decode_best_spans(encoded_inputs, outputs) + >>> print(predicted_spans[0].text) # best span + a song + ```""" + input_ids = reader_input["input_ids"] + start_logits, end_logits, relevance_logits = reader_output[:3] + n_passages = len(relevance_logits) + sorted_docs = sorted(range(n_passages), reverse=True, key=relevance_logits.__getitem__) + nbest_spans_predictions: List[DPRReaderOutput] = [] + for doc_id in sorted_docs: + sequence_ids = list(input_ids[doc_id]) + # assuming question & title information is at the beginning of the sequence + passage_offset = sequence_ids.index(self.sep_token_id, 2) + 1 # second sep id + if sequence_ids[-1] == self.pad_token_id: + sequence_len = sequence_ids.index(self.pad_token_id) + else: + sequence_len = len(sequence_ids) + + best_spans = self._get_best_spans( + start_logits=start_logits[doc_id][passage_offset:sequence_len], + end_logits=end_logits[doc_id][passage_offset:sequence_len], + max_answer_length=max_answer_length, + top_spans=num_spans_per_passage, + ) + for start_index, end_index in best_spans: + start_index += passage_offset + end_index += passage_offset + nbest_spans_predictions.append( + DPRSpanPrediction( + span_score=start_logits[doc_id][start_index] + end_logits[doc_id][end_index], + relevance_score=relevance_logits[doc_id], + doc_id=doc_id, + start_index=start_index, + end_index=end_index, + text=self.decode(sequence_ids[start_index : end_index + 1]), + ) + ) + if len(nbest_spans_predictions) >= num_spans: + break + return nbest_spans_predictions[:num_spans] + + def _get_best_spans( + self, + start_logits: List[int], + end_logits: List[int], + max_answer_length: int, + top_spans: int, + ) -> List[DPRSpanPrediction]: + """ + Finds the best answer span for the extractive Q&A model for one passage. It returns the best span by descending + `span_score` order and keeping max `top_spans` spans. Spans longer that `max_answer_length` are ignored. + """ + scores = [] + for start_index, start_score in enumerate(start_logits): + for answer_length, end_score in enumerate(end_logits[start_index : start_index + max_answer_length]): + scores.append(((start_index, start_index + answer_length), start_score + end_score)) + scores = sorted(scores, key=lambda x: x[1], reverse=True) + chosen_span_intervals = [] + for (start_index, end_index), score in scores: + assert start_index <= end_index, f"Wrong span indices: [{start_index}:{end_index}]" + length = end_index - start_index + 1 + assert length <= max_answer_length, f"Span is too long: {length} > {max_answer_length}" + if any( + start_index <= prev_start_index <= prev_end_index <= end_index + or prev_start_index <= start_index <= end_index <= prev_end_index + for (prev_start_index, prev_end_index) in chosen_span_intervals + ): + continue + chosen_span_intervals.append((start_index, end_index)) + + if len(chosen_span_intervals) == top_spans: + break + return chosen_span_intervals + + +@add_end_docstrings(CUSTOM_DPR_READER_DOCSTRING) +class DPRReaderTokenizerFast(CustomDPRReaderTokenizerMixin, BertTokenizerFast): + r""" + Constructs a "fast" DPRReader tokenizer (backed by HuggingFace's *tokenizers* library). + + [`DPRReaderTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization: + punctuation splitting and wordpiece. The difference is that is has three inputs strings: question, titles and texts + that are combined to be fed to the [`DPRReader`] model. + + Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = DPRReaderTokenizer + + +__all__ = ["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"] diff --git a/docs/transformers/build/lib/transformers/models/dpt/__init__.py b/docs/transformers/build/lib/transformers/models/dpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..086750423dbd93ecd6f23bf50adb8e0955c1771d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_dpt import * + from .feature_extraction_dpt import * + from .image_processing_dpt import * + from .modeling_dpt import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..22f25e18423f142cd9847a3441db723b3e005025 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/configuration_dpt.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DPT model configuration""" + +import copy + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING +from ..bit import BitConfig + + +logger = logging.get_logger(__name__) + + +class DPTConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DPTModel`]. It is used to instantiate an DPT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DPT + [Intel/dpt-large](https://huggingface.co/Intel/dpt-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 384): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + is_hybrid (`bool`, *optional*, defaults to `False`): + Whether to use a hybrid backbone. Useful in the context of loading DPT-Hybrid models. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + backbone_out_indices (`List[int]`, *optional*, defaults to `[2, 5, 8, 11]`): + Indices of the intermediate hidden states to use from backbone. + readout_type (`str`, *optional*, defaults to `"project"`): + The readout type to use when processing the readout token (CLS token) of the intermediate hidden states of + the ViT backbone. Can be one of [`"ignore"`, `"add"`, `"project"`]. + + - "ignore" simply ignores the CLS token. + - "add" passes the information from the CLS token to all other tokens by adding the representations. + - "project" passes information to the other tokens by concatenating the readout to all other tokens before + projecting the + representation to the original feature dimension D using a linear layer followed by a GELU non-linearity. + reassemble_factors (`List[int]`, *optional*, defaults to `[4, 2, 1, 0.5]`): + The up/downsampling factors of the reassemble layers. + neck_hidden_sizes (`List[str]`, *optional*, defaults to `[96, 192, 384, 768]`): + The hidden sizes to project to for the feature maps of the backbone. + fusion_hidden_size (`int`, *optional*, defaults to 256): + The number of channels before fusion. + head_in_index (`int`, *optional*, defaults to -1): + The index of the features to use in the heads. + use_batch_norm_in_fusion_residual (`bool`, *optional*, defaults to `False`): + Whether to use batch normalization in the pre-activate residual units of the fusion blocks. + use_bias_in_fusion_residual (`bool`, *optional*, defaults to `True`): + Whether to use bias in the pre-activate residual units of the fusion blocks. + add_projection (`bool`, *optional*, defaults to `False`): + Whether to add a projection layer before the depth estimation head. + use_auxiliary_head (`bool`, *optional*, defaults to `True`): + Whether to use an auxiliary head during training. + auxiliary_loss_weight (`float`, *optional*, defaults to 0.4): + Weight of the cross-entropy loss of the auxiliary head. + semantic_loss_ignore_index (`int`, *optional*, defaults to 255): + The index that is ignored by the loss function of the semantic segmentation model. + semantic_classifier_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the semantic classification head. + backbone_featmap_shape (`List[int]`, *optional*, defaults to `[1, 1024, 24, 24]`): + Used only for the `hybrid` embedding type. The shape of the feature maps of the backbone. + neck_ignore_stages (`List[int]`, *optional*, defaults to `[0, 1]`): + Used only for the `hybrid` embedding type. The stages of the readout layers to ignore. + backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*): + The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to + leverage the [`AutoBackbone`] API. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + pooler_output_size (`int`, *optional*): + Dimensionality of the pooler layer. If None, defaults to `hidden_size`. + pooler_act (`str`, *optional*, defaults to `"tanh"`): + The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and + Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are + supported for Tensorflow. + + Example: + + ```python + >>> from transformers import DPTModel, DPTConfig + + >>> # Initializing a DPT dpt-large style configuration + >>> configuration = DPTConfig() + + >>> # Initializing a model from the dpt-large style configuration + >>> model = DPTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "dpt" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=384, + patch_size=16, + num_channels=3, + is_hybrid=False, + qkv_bias=True, + backbone_out_indices=[2, 5, 8, 11], + readout_type="project", + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_sizes=[96, 192, 384, 768], + fusion_hidden_size=256, + head_in_index=-1, + use_batch_norm_in_fusion_residual=False, + use_bias_in_fusion_residual=None, + add_projection=False, + use_auxiliary_head=True, + auxiliary_loss_weight=0.4, + semantic_loss_ignore_index=255, + semantic_classifier_dropout=0.1, + backbone_featmap_shape=[1, 1024, 24, 24], + neck_ignore_stages=[0, 1], + backbone_config=None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + pooler_output_size=None, + pooler_act="tanh", + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.is_hybrid = is_hybrid + + use_autobackbone = False + if self.is_hybrid: + if backbone_config is None: + backbone_config = { + "global_padding": "same", + "layer_type": "bottleneck", + "depths": [3, 4, 9], + "out_features": ["stage1", "stage2", "stage3"], + "embedding_dynamic_padding": True, + } + + if isinstance(backbone_config, dict): + logger.info("Initializing the config with a `BiT` backbone.") + backbone_config = BitConfig(**backbone_config) + elif isinstance(backbone_config, PretrainedConfig): + backbone_config = backbone_config + else: + raise ValueError( + f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}." + ) + self.backbone_config = backbone_config + self.backbone_featmap_shape = backbone_featmap_shape + self.neck_ignore_stages = neck_ignore_stages + + if readout_type != "project": + raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.") + + elif backbone is not None or backbone_config is not None: + use_autobackbone = True + if isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + self.backbone_config = backbone_config + self.backbone_featmap_shape = None + self.neck_ignore_stages = [] + + # We only use load_backbone when config.is_hydrid is False + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + else: + self.backbone_config = None + self.backbone_featmap_shape = None + self.neck_ignore_stages = [] + + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + + # ViT parameters used if not using a hybrid backbone + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.use_autobackbone = use_autobackbone + self.backbone_out_indices = None if use_autobackbone else backbone_out_indices + + if readout_type not in ["ignore", "add", "project"]: + raise ValueError("Readout_type must be one of ['ignore', 'add', 'project']") + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.readout_type = readout_type + self.reassemble_factors = reassemble_factors + self.neck_hidden_sizes = neck_hidden_sizes + self.fusion_hidden_size = fusion_hidden_size + self.head_in_index = head_in_index + self.use_batch_norm_in_fusion_residual = use_batch_norm_in_fusion_residual + self.use_bias_in_fusion_residual = use_bias_in_fusion_residual + self.add_projection = add_projection + + # auxiliary head attributes (semantic segmentation) + self.use_auxiliary_head = use_auxiliary_head + self.auxiliary_loss_weight = auxiliary_loss_weight + self.semantic_loss_ignore_index = semantic_loss_ignore_index + self.semantic_classifier_dropout = semantic_classifier_dropout + self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size + self.pooler_act = pooler_act + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + + if output["backbone_config"] is not None: + output["backbone_config"] = self.backbone_config.to_dict() + + output["model_type"] = self.__class__.model_type + return output + + @property + def sub_configs(self): + return {"backbone_config": type(self.backbone_config)} if self.backbone_config is not None else {} + + +__all__ = ["DPTConfig"] diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..367aff7f90e18bab33dc81bfc82148550d73b944 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dinov2_depth_to_hf.py @@ -0,0 +1,383 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DINOv2 + DPT checkpoints from the original repository. URL: +https://github.com/facebookresearch/dinov2/tree/main""" + +import argparse +import itertools +import math +from pathlib import Path + +import requests +import torch +from PIL import Image +from torchvision import transforms + +from transformers import Dinov2Config, DPTConfig, DPTForDepthEstimation, DPTImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + if "small" in model_name: + # equivalent to stage 3, stage 6, stage 9, stage 12 + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-small", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False + ) + neck_hidden_sizes = [48, 96, 192, 384] + elif "base" in model_name: + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-base", out_indices=[3, 6, 9, 12], apply_layernorm=False, reshape_hidden_states=False + ) + neck_hidden_sizes = [96, 192, 384, 768] + elif "large" in model_name: + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-large", out_indices=[5, 12, 18, 24], apply_layernorm=False, reshape_hidden_states=False + ) + neck_hidden_sizes = [128, 256, 512, 1024] + elif "giant" in model_name: + backbone_config = Dinov2Config.from_pretrained( + "facebook/dinov2-giant", out_indices=[10, 20, 30, 40], apply_layernorm=False, reshape_hidden_states=False + ) + neck_hidden_sizes = [192, 384, 768, 1536] + else: + raise NotImplementedError("To do") + + config = DPTConfig( + backbone_config=backbone_config, + neck_hidden_sizes=neck_hidden_sizes, + use_bias_in_fusion_residual=False, + add_projection=True, + ) + + return config + + +# here we list all DPT keys to be renamed (original name on the left, our name on the right) +def create_rename_keys_dpt(config): + rename_keys = [] + + # fmt: off + # activation postprocessing (projections, readout projections + resize blocks) + for i in range(4): + rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.weight", f"neck.reassemble_stage.layers.{i}.projection.weight")) + rename_keys.append((f"decode_head.reassemble_blocks.projects.{i}.conv.bias", f"neck.reassemble_stage.layers.{i}.projection.bias")) + + rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.weight", f"neck.reassemble_stage.readout_projects.{i}.0.weight")) + rename_keys.append((f"decode_head.reassemble_blocks.readout_projects.{i}.0.bias", f"neck.reassemble_stage.readout_projects.{i}.0.bias")) + + if i != 2: + rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.weight", f"neck.reassemble_stage.layers.{i}.resize.weight")) + rename_keys.append((f"decode_head.reassemble_blocks.resize_layers.{i}.bias", f"neck.reassemble_stage.layers.{i}.resize.bias")) + + # fusion layers + for i in range(4): + rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.weight", f"neck.fusion_stage.layers.{i}.projection.weight")) + rename_keys.append((f"decode_head.fusion_blocks.{i}.project.conv.bias", f"neck.fusion_stage.layers.{i}.projection.bias")) + if i != 0: + rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution1.weight")) + rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit1.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer1.convolution2.weight")) + rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv1.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution1.weight")) + rename_keys.append((f"decode_head.fusion_blocks.{i}.res_conv_unit2.conv2.conv.weight", f"neck.fusion_stage.layers.{i}.residual_layer2.convolution2.weight")) + + # neck convolutions + for i in range(4): + rename_keys.append((f"decode_head.convs.{i}.conv.weight", f"neck.convs.{i}.weight")) + + # head + rename_keys.append(("decode_head.project.conv.weight", "head.projection.weight")) + rename_keys.append(("decode_head.project.conv.bias", "head.projection.bias")) + + for i in range(0, 5, 2): + rename_keys.append((f"decode_head.conv_depth.head.{i}.weight", f"head.head.{i}.weight")) + rename_keys.append((f"decode_head.conv_depth.head.{i}.bias", f"head.head.{i}.bias")) + # fmt: on + + return rename_keys + + +# here we list all backbone keys to be renamed (original name on the left, our name on the right) +def create_rename_keys_backbone(config): + rename_keys = [] + + # fmt: off + # patch embedding layer + rename_keys.append(("cls_token", "backbone.embeddings.cls_token")) + rename_keys.append(("mask_token", "backbone.embeddings.mask_token")) + rename_keys.append(("pos_embed", "backbone.embeddings.position_embeddings")) + rename_keys.append(("patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + + # Transfomer encoder + for i in range(config.backbone_config.num_hidden_layers): + # layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.norm1.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.norm1.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.norm2.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.norm2.bias")) + # MLP + if config.backbone_config.use_swiglu_ffn: + rename_keys.append((f"blocks.{i}.mlp.w12.weight", f"backbone.encoder.layer.{i}.mlp.w12.weight")) + rename_keys.append((f"blocks.{i}.mlp.w12.bias", f"backbone.encoder.layer.{i}.mlp.w12.bias")) + rename_keys.append((f"blocks.{i}.mlp.w3.weight", f"backbone.encoder.layer.{i}.mlp.w3.weight")) + rename_keys.append((f"blocks.{i}.mlp.w3.bias", f"backbone.encoder.layer.{i}.mlp.w3.bias")) + else: + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.mlp.fc1.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.mlp.fc1.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.mlp.fc2.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.mlp.fc2.bias")) + # layerscale + rename_keys.append((f"blocks.{i}.ls1.gamma", f"backbone.encoder.layer.{i}.layer_scale1.lambda1")) + rename_keys.append((f"blocks.{i}.ls2.gamma", f"backbone.encoder.layer.{i}.layer_scale2.lambda1")) + # attention projection layer + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias")) + # fmt: on + + rename_keys.append(("norm.weight", "backbone.layernorm.weight")) + rename_keys.append(("norm.bias", "backbone.layernorm.bias")) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.backbone_config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + hidden_size = config.backbone_config.hidden_size + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[:hidden_size] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + hidden_size : hidden_size * 2 + ] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-hidden_size:] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +name_to_url = { + "dpt-dinov2-small-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth", + "dpt-dinov2-small-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth", + "dpt-dinov2-base-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth", + "dpt-dinov2-base-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth", + "dpt-dinov2-large-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth", + "dpt-dinov2-large-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth", + "dpt-dinov2-giant-nyu": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth", + "dpt-dinov2-giant-kitti": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth", +} + + +def get_original_pixel_values(image): + class CenterPadding: + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + def __call__(self, img): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in img.shape[-2:][::-1])) + output = torch.nn.functional.pad(img, pads) + return output + + def __repr__(self): + return self.__class__.__name__ + "()" + + def make_depth_transform() -> transforms.Compose: + return transforms.Compose( + [ + transforms.ToTensor(), + lambda x: 255.0 * x[:3], # Discard alpha component and scale by 255 + transforms.Normalize( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + ), + CenterPadding(multiple=14), + ] + ) + + transform = make_depth_transform() + original_pixel_values = transform(image).unsqueeze(0) + + return original_pixel_values + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub, verify_logits): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + # define DPT configuration based on URL + checkpoint_url = name_to_url[model_name] + config = get_dpt_config(model_name) + + # load original DPT state_dict from URL + print("URL:", checkpoint_url) + dpt_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["state_dict"] + # rename keys + rename_keys = create_rename_keys_dpt(config) + for src, dest in rename_keys: + rename_key(dpt_state_dict, src, dest) + + # load original backbone state_dict from URL + if "small" in model_name: + original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14") + elif "base" in model_name: + original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14") + elif "large" in model_name: + original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14") + elif "giant" in model_name: + original_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14") + else: + raise NotImplementedError("To do") + original_model.eval() + backbone_state_dict = original_model.state_dict() + + # rename keys + rename_keys = create_rename_keys_backbone(config) + for src, dest in rename_keys: + rename_key(backbone_state_dict, src, dest) + + # read in qkv matrices + read_in_q_k_v(backbone_state_dict, config) + + for key, val in backbone_state_dict.copy().items(): + val = backbone_state_dict.pop(key) + if "w12" in key: + key = key.replace("w12", "weights_in") + if "w3" in key: + key = key.replace("w3", "weights_out") + backbone_state_dict[key] = val + + # merge state_dicts + state_dict = {**backbone_state_dict, **dpt_state_dict} + + # load HuggingFace model + model = DPTForDepthEstimation(config) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + assert missing_keys == [ + "neck.fusion_stage.layers.0.residual_layer1.convolution1.weight", + "neck.fusion_stage.layers.0.residual_layer1.convolution2.weight", + ] + model.eval() + + # Verify image processor + processor = DPTImageProcessor( + do_resize=False, + do_rescale=False, + do_pad=True, + size_divisor=14, + do_normalize=True, + image_mean=(123.675, 116.28, 103.53), + image_std=(58.395, 57.12, 57.375), + ) + + image = prepare_img() + pixel_values = processor(image, return_tensors="pt").pixel_values.float() + original_pixel_values = get_original_pixel_values(image) + + assert torch.allclose(pixel_values, original_pixel_values) + + # Verify forward pass + with torch.no_grad(): + outputs = model(pixel_values) + + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values of predicted depth:", predicted_depth[0, :3, :3]) + + # assert logits + if verify_logits: + if model_name == "dpt-dinov2-small-nyu": + expected_shape = torch.Size([1, 576, 736]) + expected_slice = torch.tensor( + [[3.3576, 3.4741, 3.4345], [3.4324, 3.5012, 3.2775], [3.2560, 3.3563, 3.2354]] + ) + + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-5) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"facebook/{model_name}") + processor.push_to_hub(repo_id=f"facebook/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dpt-dinov2-small-nyu", + type=str, + choices=name_to_url.keys(), + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + parser.add_argument( + "--verify_logits", + action="store_true", + required=False, + help="Path to the output PyTorch model directory.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub, args.verify_logits) diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..3a576d772f577b0f690fb3300c3dd75203f700a6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_beit_to_hf.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS""" + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import BeitConfig, DPTConfig, DPTForDepthEstimation, DPTImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + hidden_size = 768 + num_hidden_layers = 12 + num_attention_heads = 12 + intermediate_size = 3072 + out_features = ["stage3", "stage6", "stage9", "stage12"] # beit-base-384 uses [2, 5, 8, 11] + + if "large" in model_name: + hidden_size = 1024 + num_hidden_layers = 24 + num_attention_heads = 16 + intermediate_size = 4096 + out_features = ["stage6", "stage12", "stage18", "stage24"] # beit-large-512 uses [5, 11, 17, 23] + + if "512" in model_name: + image_size = 512 + elif "384" in model_name: + image_size = 384 + else: + raise ValueError("Model not supported") + + backbone_config = BeitConfig( + image_size=image_size, + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + use_relative_position_bias=True, + reshape_hidden_states=False, + out_features=out_features, + ) + + neck_hidden_sizes = [256, 512, 1024, 1024] if "large" in model_name else [96, 192, 384, 768] + config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes) + + return config, image_size + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("pretrained.model.cls_token", "backbone.embeddings.cls_token")) + rename_keys.append(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + + # Transfomer encoder + for i in range(config.backbone_config.num_hidden_layers): + rename_keys.append((f"pretrained.model.blocks.{i}.gamma_1", f"backbone.encoder.layer.{i}.lambda_1")) + rename_keys.append((f"pretrained.model.blocks.{i}.gamma_2", f"backbone.encoder.layer.{i}.lambda_2")) + rename_keys.append((f"pretrained.model.blocks.{i}.norm1.weight", f"backbone.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"pretrained.model.blocks.{i}.norm1.bias", f"backbone.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"pretrained.model.blocks.{i}.norm2.weight", f"backbone.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"pretrained.model.blocks.{i}.norm2.bias", f"backbone.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc1.weight", f"backbone.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc1.bias", f"backbone.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc2.weight", f"backbone.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"pretrained.model.blocks.{i}.mlp.fc2.bias", f"backbone.encoder.layer.{i}.output.dense.bias")) + rename_keys.append((f"pretrained.model.blocks.{i}.attn.proj.weight", f"backbone.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"pretrained.model.blocks.{i}.attn.proj.bias", f"backbone.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"pretrained.model.blocks.{i}.attn.relative_position_bias_table", f"backbone.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table")) + rename_keys.append((f"pretrained.model.blocks.{i}.attn.relative_position_index", f"backbone.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index")) + + # activation postprocessing (readout projections + resize blocks) + for i in range(4): + rename_keys.append((f"pretrained.act_postprocess{i+1}.0.project.0.weight", f"neck.reassemble_stage.readout_projects.{i}.0.weight")) + rename_keys.append((f"pretrained.act_postprocess{i+1}.0.project.0.bias", f"neck.reassemble_stage.readout_projects.{i}.0.bias")) + + rename_keys.append((f"pretrained.act_postprocess{i+1}.3.weight", f"neck.reassemble_stage.layers.{i}.projection.weight")) + rename_keys.append((f"pretrained.act_postprocess{i+1}.3.bias", f"neck.reassemble_stage.layers.{i}.projection.bias")) + + if i != 2: + rename_keys.append((f"pretrained.act_postprocess{i+1}.4.weight", f"neck.reassemble_stage.layers.{i}.resize.weight")) + rename_keys.append((f"pretrained.act_postprocess{i+1}.4.bias", f"neck.reassemble_stage.layers.{i}.resize.bias")) + + # refinenet (tricky here) + mapping = {1:3, 2:2, 3:1, 4:0} + + for i in range(1, 5): + j = mapping[i] + rename_keys.append((f"scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight")) + rename_keys.append((f"scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias")) + + # scratch convolutions + for i in range(4): + rename_keys.append((f"scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight")) + + # head + for i in range(0, 5, 2): + rename_keys.append((f"scratch.output_conv.{i}.weight", f"head.head.{i}.weight")) + rename_keys.append((f"scratch.output_conv.{i}.bias", f"head.head.{i}.bias")) + + return rename_keys + + +def remove_ignore_keys_(state_dict): + ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + hidden_size = config.backbone_config.hidden_size + for i in range(config.backbone_config.num_hidden_layers): + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"pretrained.model.blocks.{i}.attn.qkv.weight") + q_bias = state_dict.pop(f"pretrained.model.blocks.{i}.attn.q_bias") + v_bias = state_dict.pop(f"pretrained.model.blocks.{i}.attn.v_bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[:hidden_size, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.query.bias"] = q_bias + state_dict[f"backbone.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + hidden_size : hidden_size * 2, : + ] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-hidden_size:, :] + state_dict[f"backbone.encoder.layer.{i}.attention.attention.value.bias"] = v_bias + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + name_to_url = { + "dpt-beit-large-512": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt", + "dpt-beit-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt", + "dpt-beit-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt", + } + + # define DPT configuration based on URL + checkpoint_url = name_to_url[model_name] + config, image_size = get_dpt_config(model_name) + # load original state_dict from URL + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + # remove certain keys + remove_ignore_keys_(state_dict) + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # read in qkv matrices + read_in_q_k_v(state_dict, config) + + # load HuggingFace model + model = DPTForDepthEstimation(config) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + assert missing_keys == [] + # assert unexpected_keys == ["pretrained.model.fc_norm.weight", "pretrained.model.fc_norm.bias"] + model.eval() + + # Check outputs on an image + # We set `keep_aspect_ratio=False` as our current BEiT does not support arbitrary window sizes + processor = DPTImageProcessor( + size={"height": image_size, "width": image_size}, keep_aspect_ratio=False, ensure_multiple_of=32 + ) + + image = prepare_img() + pixel_values = processor(image, return_tensors="pt").pixel_values + + print("First values of pixel values:", pixel_values[0, 0, :3, :3]) + print("Mean of pixel values:", pixel_values.mean().item()) + print("Shape of pixel values:", pixel_values.shape) + + import requests + from PIL import Image + from torchvision import transforms + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + transforms = transforms.Compose( + [ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + ] + ) + pixel_values = transforms(image).unsqueeze(0) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values) + + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values of predicted depth:", predicted_depth[0, :3, :3]) + + # assert logits + # TODO there's still a small difference with the original logits + if model_name == "dpt-beit-large-512": + # OK, checked + expected_shape = torch.Size([1, 512, 512]) + expected_slice = torch.tensor( + [[2804.6260, 2792.5708, 2812.9263], [2772.0288, 2780.1118, 2796.2529], [2748.1094, 2766.6558, 2766.9834]] + ) + elif model_name == "dpt-beit-large-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [[1783.2273, 1780.5729, 1792.6453], [1759.9817, 1765.5359, 1778.5002], [1739.1633, 1754.7903, 1757.1990]], + ) + elif model_name == "dpt-beit-base-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [[2898.4482, 2891.3750, 2904.8079], [2858.6685, 2877.2615, 2894.4507], [2842.1235, 2854.1023, 2861.6328]], + ) + + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"nielsr/{model_name}") + processor.push_to_hub(repo_id=f"nielsr/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dpt-beit-large-512", + type=str, + choices=["dpt-beit-large-512", "dpt-beit-large-384", "dpt-beit-base-384"], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ceae9b84711463261f20e70b7cd174437a1666cd --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DPT checkpoints from the original repository. URL: https://github.com/isl-org/DPT""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import DPTConfig, DPTForDepthEstimation, DPTForSemanticSegmentation, DPTImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(checkpoint_url): + config = DPTConfig(embedding_type="hybrid") + + if "large" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.backbone_out_indices = [5, 11, 17, 23] + config.neck_hidden_sizes = [256, 512, 1024, 1024] + expected_shape = (1, 384, 384) + + if "nyu" in checkpoint_url or "midas" in checkpoint_url: + config.hidden_size = 768 + config.reassemble_factors = [1, 1, 1, 0.5] + config.neck_hidden_sizes = [256, 512, 768, 768] + config.num_labels = 150 + config.patch_size = 16 + expected_shape = (1, 384, 384) + config.use_batch_norm_in_fusion_residual = False + config.readout_type = "project" + + if "ade" in checkpoint_url: + config.use_batch_norm_in_fusion_residual = True + config.hidden_size = 768 + config.reassemble_stage = [1, 1, 1, 0.5] + config.num_labels = 150 + config.patch_size = 16 + repo_id = "huggingface/label-files" + filename = "ade20k-id2label.json" + id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text()) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + expected_shape = [1, 150, 480, 480] + + return config, expected_shape + + +def remove_ignore_keys_(state_dict): + ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(name): + if ( + "pretrained.model" in name + and "cls_token" not in name + and "pos_embed" not in name + and "patch_embed" not in name + ): + name = name.replace("pretrained.model", "dpt.encoder") + if "pretrained.model" in name: + name = name.replace("pretrained.model", "dpt.embeddings") + if "patch_embed" in name: + name = name.replace("patch_embed", "") + if "pos_embed" in name: + name = name.replace("pos_embed", "position_embeddings") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "proj" in name and "project" not in name: + name = name.replace("proj", "projection") + if "blocks" in name: + name = name.replace("blocks", "layer") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "norm1" in name and "backbone" not in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name and "backbone" not in name: + name = name.replace("norm2", "layernorm_after") + if "scratch.output_conv" in name: + name = name.replace("scratch.output_conv", "head") + if "scratch" in name: + name = name.replace("scratch", "neck") + if "layer1_rn" in name: + name = name.replace("layer1_rn", "convs.0") + if "layer2_rn" in name: + name = name.replace("layer2_rn", "convs.1") + if "layer3_rn" in name: + name = name.replace("layer3_rn", "convs.2") + if "layer4_rn" in name: + name = name.replace("layer4_rn", "convs.3") + if "refinenet" in name: + layer_idx = int(name[len("neck.refinenet") : len("neck.refinenet") + 1]) + # tricky here: we need to map 4 to 0, 3 to 1, 2 to 2 and 1 to 3 + name = name.replace(f"refinenet{layer_idx}", f"fusion_stage.layers.{abs(layer_idx - 4)}") + if "out_conv" in name: + name = name.replace("out_conv", "projection") + if "resConfUnit1" in name: + name = name.replace("resConfUnit1", "residual_layer1") + if "resConfUnit2" in name: + name = name.replace("resConfUnit2", "residual_layer2") + if "conv1" in name: + name = name.replace("conv1", "convolution1") + if "conv2" in name: + name = name.replace("conv2", "convolution2") + # readout blocks + if "pretrained.act_postprocess1.0.project.0" in name: + name = name.replace("pretrained.act_postprocess1.0.project.0", "neck.reassemble_stage.readout_projects.0.0") + if "pretrained.act_postprocess2.0.project.0" in name: + name = name.replace("pretrained.act_postprocess2.0.project.0", "neck.reassemble_stage.readout_projects.1.0") + if "pretrained.act_postprocess3.0.project.0" in name: + name = name.replace("pretrained.act_postprocess3.0.project.0", "neck.reassemble_stage.readout_projects.2.0") + if "pretrained.act_postprocess4.0.project.0" in name: + name = name.replace("pretrained.act_postprocess4.0.project.0", "neck.reassemble_stage.readout_projects.3.0") + + # resize blocks + if "pretrained.act_postprocess1.3" in name: + name = name.replace("pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection") + if "pretrained.act_postprocess1.4" in name: + name = name.replace("pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize") + if "pretrained.act_postprocess2.3" in name: + name = name.replace("pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection") + if "pretrained.act_postprocess2.4" in name: + name = name.replace("pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize") + if "pretrained.act_postprocess3.3" in name: + name = name.replace("pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection") + if "pretrained.act_postprocess4.3" in name: + name = name.replace("pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection") + if "pretrained.act_postprocess4.4" in name: + name = name.replace("pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize") + if "pretrained" in name: + name = name.replace("pretrained", "dpt") + if "bn" in name: + name = name.replace("bn", "batch_norm") + if "head" in name: + name = name.replace("head", "head.head") + if "encoder.norm" in name: + name = name.replace("encoder.norm", "layernorm") + if "auxlayer" in name: + name = name.replace("auxlayer", "auxiliary_head.head") + if "backbone" in name: + name = name.replace("backbone", "backbone.bit.encoder") + + if ".." in name: + name = name.replace("..", ".") + + if "stem.conv" in name: + name = name.replace("stem.conv", "bit.embedder.convolution") + if "blocks" in name: + name = name.replace("blocks", "layers") + if "convolution" in name and "backbone" in name: + name = name.replace("convolution", "conv") + if "layer" in name and "backbone" in name: + name = name.replace("layer", "layers") + if "backbone.bit.encoder.bit" in name: + name = name.replace("backbone.bit.encoder.bit", "backbone.bit") + if "embedder.conv" in name: + name = name.replace("embedder.conv", "embedder.convolution") + if "backbone.bit.encoder.stem.norm" in name: + name = name.replace("backbone.bit.encoder.stem.norm", "backbone.bit.embedder.norm") + return name + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub, model_name, show_prediction): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + # define DPT configuration based on URL + config, expected_shape = get_dpt_config(checkpoint_url) + # load original state_dict from URL + # state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + state_dict = torch.load(checkpoint_url, map_location="cpu", weights_only=True) + # remove certain keys + remove_ignore_keys_(state_dict) + # rename keys + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + state_dict[rename_key(key)] = val + # read in qkv matrices + read_in_q_k_v(state_dict, config) + + # load HuggingFace model + model = DPTForSemanticSegmentation(config) if "ade" in checkpoint_url else DPTForDepthEstimation(config) + model.load_state_dict(state_dict) + model.eval() + + # Check outputs on an image + size = 480 if "ade" in checkpoint_url else 384 + image_processor = DPTImageProcessor(size=size) + + image = prepare_img() + encoding = image_processor(image, return_tensors="pt") + + # forward pass + outputs = model(**encoding).logits if "ade" in checkpoint_url else model(**encoding).predicted_depth + + if show_prediction: + prediction = ( + torch.nn.functional.interpolate( + outputs.unsqueeze(1), + size=(image.size[1], image.size[0]), + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + Image.fromarray((prediction / prediction.max()) * 255).show() + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub("ybelkada/dpt-hybrid-midas") + image_processor.push_to_hub("ybelkada/dpt-hybrid-midas") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", + type=str, + help="URL of the original DPT checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=False, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + ) + parser.add_argument( + "--model_name", + default="dpt-large", + type=str, + help="Name of the model, in case you're pushing to the hub.", + ) + parser.add_argument( + "--show_prediction", + action="store_true", + ) + + args = parser.parse_args() + convert_dpt_checkpoint( + args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub, args.model_name, args.show_prediction + ) diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..0feebe72d47419b1bce08a25f85fda81c3822210 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_swinv2_to_hf.py @@ -0,0 +1,321 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DPT 3.1 checkpoints from the MiDaS repository. URL: https://github.com/isl-org/MiDaS""" + +import argparse +from pathlib import Path + +import requests +import torch +from PIL import Image + +from transformers import DPTConfig, DPTForDepthEstimation, DPTImageProcessor, Swinv2Config +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(model_name): + if "tiny" in model_name: + embed_dim = 96 + depths = (2, 2, 6, 2) + num_heads = (3, 6, 12, 24) + window_size = 16 + # note: for Swinv2-tiny authors used the window_size = 16 variant + # as seen here: https://github.com/isl-org/MiDaS/blob/bdc4ed64c095e026dc0a2f17cabb14d58263decb/midas/backbones/swin2.py#L26 + pretrained_window_sizes = (0, 0, 0, 0) + elif "base" in model_name: + embed_dim = 128 + depths = (2, 2, 18, 2) + num_heads = (4, 8, 16, 32) + window_size = 24 + pretrained_window_sizes = (12, 12, 12, 6) + elif "large" in model_name: + embed_dim = 192 + depths = (2, 2, 18, 2) + num_heads = (6, 12, 24, 48) + window_size = 24 + pretrained_window_sizes = (12, 12, 12, 6) + + if "384" in model_name: + image_size = 384 + elif "256" in model_name: + image_size = 256 + else: + raise ValueError("Model not supported, to do") + + backbone_config = Swinv2Config( + image_size=image_size, + embed_dim=embed_dim, + depths=depths, + window_size=window_size, + pretrained_window_sizes=pretrained_window_sizes, + num_heads=num_heads, + out_features=["stage1", "stage2", "stage3", "stage4"], + ) + + if model_name == "dpt-swinv2-tiny-256": + neck_hidden_sizes = [96, 192, 384, 768] + elif model_name == "dpt-swinv2-base-384": + neck_hidden_sizes = [128, 256, 512, 1024] + elif model_name == "dpt-swinv2-large-384": + neck_hidden_sizes = [192, 384, 768, 1536] + + config = DPTConfig(backbone_config=backbone_config, neck_hidden_sizes=neck_hidden_sizes) + + return config, image_size + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("pretrained.model.patch_embed.proj.weight", "backbone.embeddings.patch_embeddings.projection.weight")) + rename_keys.append(("pretrained.model.patch_embed.proj.bias", "backbone.embeddings.patch_embeddings.projection.bias")) + rename_keys.append(("pretrained.model.patch_embed.norm.weight", "backbone.embeddings.norm.weight")) + rename_keys.append(("pretrained.model.patch_embed.norm.bias", "backbone.embeddings.norm.bias")) + + # transformer encoder + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.logit_scale", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.logit_scale")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.0.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.0.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.cpb_mlp.2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.continuous_position_bias_mlp.2.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.q_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.v_bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.weight", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.attn.proj.bias", f"backbone.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_before.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.weight", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc1.bias", f"backbone.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.mlp.fc2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.output.dense.bias")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.weight", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.blocks.{j}.norm2.bias", f"backbone.encoder.layers.{i}.blocks.{j}.layernorm_after.bias")) + + # downsample parameters + if i in [0,1,2]: + rename_keys.append((f"pretrained.model.layers.{i}.downsample.reduction.weight", f"backbone.encoder.layers.{i}.downsample.reduction.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.weight", f"backbone.encoder.layers.{i}.downsample.norm.weight")) + rename_keys.append((f"pretrained.model.layers.{i}.downsample.norm.bias", f"backbone.encoder.layers.{i}.downsample.norm.bias")) + + # note: non-Transformer backbones like Swinv2, LeViT et al don't require activation postprocessing (readout projections + resize blocks) + + # refinenet (tricky here) + mapping = {1:3, 2:2, 3:1, 4:0} + + for i in range(1, 5): + j = mapping[i] + rename_keys.append((f"scratch.refinenet{i}.out_conv.weight", f"neck.fusion_stage.layers.{j}.projection.weight")) + rename_keys.append((f"scratch.refinenet{i}.out_conv.bias", f"neck.fusion_stage.layers.{j}.projection.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit1.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer1.convolution2.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv1.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution1.bias")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.weight", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.weight")) + rename_keys.append((f"scratch.refinenet{i}.resConfUnit2.conv2.bias", f"neck.fusion_stage.layers.{j}.residual_layer2.convolution2.bias")) + + # scratch convolutions + for i in range(4): + rename_keys.append((f"scratch.layer{i+1}_rn.weight", f"neck.convs.{i}.weight")) + + # head + for i in range(0, 5, 2): + rename_keys.append((f"scratch.output_conv.{i}.weight", f"head.head.{i}.weight")) + rename_keys.append((f"scratch.output_conv.{i}.bias", f"head.head.{i}.bias")) + + return rename_keys + + +def remove_ignore_keys_(state_dict): + ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, model): + for i in range(len(config.backbone_config.depths)): + for j in range(config.backbone_config.depths[i]): + dim = model.backbone.encoder.layers[i].blocks[j].attention.self.all_head_size + # read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"pretrained.model.layers.{i}.blocks.{j}.attn.qkv.weight") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[ + dim : dim * 2, : + ] + state_dict[f"backbone.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[ + -dim:, : + ] + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dpt_checkpoint(model_name, pytorch_dump_folder_path, verify_logits, push_to_hub): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + name_to_url = { + "dpt-swinv2-tiny-256": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt", + "dpt-swinv2-base-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt", + "dpt-swinv2-large-384": "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt", + } + + # define DPT configuration based on URL + checkpoint_url = name_to_url[model_name] + config, image_size = get_dpt_config(model_name) + # load original state_dict from URL + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + + # load HuggingFace model + model = DPTForDepthEstimation(config) + + # remove certain keys + remove_ignore_keys_(state_dict) + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + # read in qkv matrices + read_in_q_k_v(state_dict, config, model) + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + model.eval() + + # Check outputs on an image + processor = DPTImageProcessor(size={"height": image_size, "width": image_size}) + + image = prepare_img() + processor(image, return_tensors="pt") + + if verify_logits: + from torchvision import transforms + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + transforms = transforms.Compose( + [ + transforms.Resize((image_size, image_size)), + transforms.ToTensor(), + ] + ) + pixel_values = transforms(image).unsqueeze(0) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values) + + predicted_depth = outputs.predicted_depth + + print("Shape of predicted depth:", predicted_depth.shape) + print("First values of predicted depth:", predicted_depth[0, :3, :3]) + + # assert logits + if model_name == "dpt-swinv2-base-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [ + [1998.5575, 1997.3887, 2009.2981], + [1952.8607, 1979.6488, 2001.0854], + [1953.7697, 1961.7711, 1968.8904], + ], + ) + elif model_name == "dpt-swinv2-tiny-256": + # OK, checked + expected_shape = torch.Size([1, 256, 256]) + expected_slice = torch.tensor( + [[978.9163, 976.5215, 978.5349], [974.1859, 971.7249, 975.8046], [971.3419, 970.3118, 971.6830]], + ) + elif model_name == "dpt-swinv2-large-384": + # OK, checked + expected_shape = torch.Size([1, 384, 384]) + expected_slice = torch.tensor( + [ + [1203.7206, 1200.1495, 1197.8234], + [1196.2484, 1183.5033, 1186.4640], + [1178.8131, 1182.3260, 1174.3975], + ], + ) + + assert predicted_depth.shape == torch.Size(expected_shape) + assert torch.allclose(predicted_depth[0, :3, :3], expected_slice) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model and processor to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model and processor to hub...") + model.push_to_hub(repo_id=f"Intel/{model_name}") + processor.push_to_hub(repo_id=f"Intel/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dpt-swinv2-base-384", + type=str, + choices=["dpt-swinv2-tiny-256", "dpt-swinv2-base-384", "dpt-swinv2-large-384"], + help="Name of the model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--verify_logits", + action="store_true", + help="Whether to verify logits after conversion.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether to push the model to the hub after conversion.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..55e0a444e857b0d386e685845bb4842c9f8794d7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/convert_dpt_to_pytorch.py @@ -0,0 +1,285 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert DPT checkpoints from the original repository. URL: https://github.com/isl-org/DPT""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import DPTConfig, DPTForDepthEstimation, DPTForSemanticSegmentation, DPTImageProcessor +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def get_dpt_config(checkpoint_url): + config = DPTConfig() + + if "large" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.backbone_out_indices = [5, 11, 17, 23] + config.neck_hidden_sizes = [256, 512, 1024, 1024] + expected_shape = (1, 384, 384) + + if "ade" in checkpoint_url: + config.use_batch_norm_in_fusion_residual = True + + config.num_labels = 150 + repo_id = "huggingface/label-files" + filename = "ade20k-id2label.json" + id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text()) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + expected_shape = [1, 150, 480, 480] + + return config, expected_shape + + +def remove_ignore_keys_(state_dict): + ignore_keys = ["pretrained.model.head.weight", "pretrained.model.head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(name): + if ( + "pretrained.model" in name + and "cls_token" not in name + and "pos_embed" not in name + and "patch_embed" not in name + ): + name = name.replace("pretrained.model", "dpt.encoder") + if "pretrained.model" in name: + name = name.replace("pretrained.model", "dpt.embeddings") + if "patch_embed" in name: + name = name.replace("patch_embed", "patch_embeddings") + if "pos_embed" in name: + name = name.replace("pos_embed", "position_embeddings") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "proj" in name and "project" not in name: + name = name.replace("proj", "projection") + if "blocks" in name: + name = name.replace("blocks", "layer") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "scratch.output_conv" in name: + name = name.replace("scratch.output_conv", "head") + if "scratch" in name: + name = name.replace("scratch", "neck") + if "layer1_rn" in name: + name = name.replace("layer1_rn", "convs.0") + if "layer2_rn" in name: + name = name.replace("layer2_rn", "convs.1") + if "layer3_rn" in name: + name = name.replace("layer3_rn", "convs.2") + if "layer4_rn" in name: + name = name.replace("layer4_rn", "convs.3") + if "refinenet" in name: + layer_idx = int(name[len("neck.refinenet") : len("neck.refinenet") + 1]) + # tricky here: we need to map 4 to 0, 3 to 1, 2 to 2 and 1 to 3 + name = name.replace(f"refinenet{layer_idx}", f"fusion_stage.layers.{abs(layer_idx - 4)}") + if "out_conv" in name: + name = name.replace("out_conv", "projection") + if "resConfUnit1" in name: + name = name.replace("resConfUnit1", "residual_layer1") + if "resConfUnit2" in name: + name = name.replace("resConfUnit2", "residual_layer2") + if "conv1" in name: + name = name.replace("conv1", "convolution1") + if "conv2" in name: + name = name.replace("conv2", "convolution2") + # readout blocks + if "pretrained.act_postprocess1.0.project.0" in name: + name = name.replace("pretrained.act_postprocess1.0.project.0", "neck.reassemble_stage.readout_projects.0.0") + if "pretrained.act_postprocess2.0.project.0" in name: + name = name.replace("pretrained.act_postprocess2.0.project.0", "neck.reassemble_stage.readout_projects.1.0") + if "pretrained.act_postprocess3.0.project.0" in name: + name = name.replace("pretrained.act_postprocess3.0.project.0", "neck.reassemble_stage.readout_projects.2.0") + if "pretrained.act_postprocess4.0.project.0" in name: + name = name.replace("pretrained.act_postprocess4.0.project.0", "neck.reassemble_stage.readout_projects.3.0") + # resize blocks + if "pretrained.act_postprocess1.3" in name: + name = name.replace("pretrained.act_postprocess1.3", "neck.reassemble_stage.layers.0.projection") + if "pretrained.act_postprocess1.4" in name: + name = name.replace("pretrained.act_postprocess1.4", "neck.reassemble_stage.layers.0.resize") + if "pretrained.act_postprocess2.3" in name: + name = name.replace("pretrained.act_postprocess2.3", "neck.reassemble_stage.layers.1.projection") + if "pretrained.act_postprocess2.4" in name: + name = name.replace("pretrained.act_postprocess2.4", "neck.reassemble_stage.layers.1.resize") + if "pretrained.act_postprocess3.3" in name: + name = name.replace("pretrained.act_postprocess3.3", "neck.reassemble_stage.layers.2.projection") + if "pretrained.act_postprocess4.3" in name: + name = name.replace("pretrained.act_postprocess4.3", "neck.reassemble_stage.layers.3.projection") + if "pretrained.act_postprocess4.4" in name: + name = name.replace("pretrained.act_postprocess4.4", "neck.reassemble_stage.layers.3.resize") + if "pretrained" in name: + name = name.replace("pretrained", "dpt") + if "bn" in name: + name = name.replace("bn", "batch_norm") + if "head" in name: + name = name.replace("head", "head.head") + if "encoder.norm" in name: + name = name.replace("encoder.norm", "layernorm") + if "auxlayer" in name: + name = name.replace("auxlayer", "auxiliary_head.head") + + return name + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config): + for i in range(config.num_hidden_layers): + # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"dpt.encoder.layer.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"dpt.encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_dpt_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub, model_name): + """ + Copy/paste/tweak model's weights to our DPT structure. + """ + + # define DPT configuration based on URL + config, expected_shape = get_dpt_config(checkpoint_url) + # load original state_dict from URL + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu") + # remove certain keys + remove_ignore_keys_(state_dict) + # rename keys + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + state_dict[rename_key(key)] = val + # read in qkv matrices + read_in_q_k_v(state_dict, config) + + # load HuggingFace model + model = DPTForSemanticSegmentation(config) if "ade" in checkpoint_url else DPTForDepthEstimation(config) + model.load_state_dict(state_dict) + model.eval() + + # Check outputs on an image + size = 480 if "ade" in checkpoint_url else 384 + image_processor = DPTImageProcessor(size=size) + + image = prepare_img() + encoding = image_processor(image, return_tensors="pt") + + # forward pass + outputs = model(**encoding).logits if "ade" in checkpoint_url else model(**encoding).predicted_depth + + # Assert logits + expected_slice = torch.tensor([[6.3199, 6.3629, 6.4148], [6.3850, 6.3615, 6.4166], [6.3519, 6.3176, 6.3575]]) + if "ade" in checkpoint_url: + expected_slice = torch.tensor([[4.0480, 4.2420, 4.4360], [4.3124, 4.5693, 4.8261], [4.5768, 4.8965, 5.2163]]) + assert outputs.shape == torch.Size(expected_shape) + assert ( + torch.allclose(outputs[0, 0, :3, :3], expected_slice, atol=1e-4) + if "ade" in checkpoint_url + else torch.allclose(outputs[0, :3, :3], expected_slice) + ) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print("Pushing model to hub...") + model.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, model_name), + organization="nielsr", + commit_message="Add model", + use_temp_dir=True, + ) + image_processor.push_to_hub( + repo_path_or_name=Path(pytorch_dump_folder_path, model_name), + organization="nielsr", + commit_message="Add image processor", + use_temp_dir=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", + type=str, + help="URL of the original DPT checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + required=False, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + ) + parser.add_argument( + "--model_name", + default="dpt-large", + type=str, + required=False, + help="Name of the model, in case you're pushing to the hub.", + ) + + args = parser.parse_args() + convert_dpt_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path, args.push_to_hub, args.model_name) diff --git a/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ab8ccbed8d33b1e5b15d429b6cb057ff781f78 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/feature_extraction_dpt.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for DPT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_dpt import DPTImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class DPTFeatureExtractor(DPTImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class DPTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use DPTImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["DPTFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..095cd1a48b4f08843c80b48dbff9768f98d17bb7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/image_processing_dpt.py @@ -0,0 +1,680 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for DPT.""" + +import math +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union + +from ...utils.import_utils import requires + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import pad, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_torch_available, + is_torch_tensor, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + TensorType, + filter_out_non_signature_kwargs, + is_vision_available, + logging, + requires_backends, +) + + +if is_torch_available(): + import torch + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +def get_resize_output_image_size( + input_image: np.ndarray, + output_size: Union[int, Iterable[int]], + keep_aspect_ratio: bool, + multiple: int, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> Tuple[int, int]: + def constrain_to_multiple_of(val, multiple, min_val=0, max_val=None): + x = round(val / multiple) * multiple + + if max_val is not None and x > max_val: + x = math.floor(val / multiple) * multiple + + if x < min_val: + x = math.ceil(val / multiple) * multiple + + return x + + output_size = (output_size, output_size) if isinstance(output_size, int) else output_size + + input_height, input_width = get_image_size(input_image, input_data_format) + output_height, output_width = output_size + + # determine new height and width + scale_height = output_height / input_height + scale_width = output_width / input_width + + if keep_aspect_ratio: + # scale as little as possible + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + + new_height = constrain_to_multiple_of(scale_height * input_height, multiple=multiple) + new_width = constrain_to_multiple_of(scale_width * input_width, multiple=multiple) + + return (new_height, new_width) + + +@requires(backends=("vision",)) +class DPTImageProcessor(BaseImageProcessor): + r""" + Constructs a DPT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. Can be overidden by `do_resize` in `preprocess`. + size (`Dict[str, int]` *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the image after resizing. Can be overidden by `size` in `preprocess`. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Can be overidden by `resample` in `preprocess`. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. Can + be overidden by `keep_aspect_ratio` in `preprocess`. + ensure_multiple_of (`int`, *optional*, defaults to 1): + If `do_resize` is `True`, the image is resized to a size that is a multiple of this value. Can be overidden + by `ensure_multiple_of` in `preprocess`. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overidden by `do_rescale` in + `preprocess`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overidden by `rescale_factor` in `preprocess`. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to apply center padding. This was introduced in the DINOv2 paper, which uses the model in + combination with DPT. + size_divisor (`int`, *optional*): + If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the + DINOv2 paper, which uses the model in combination with DPT. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + keep_aspect_ratio: bool = False, + ensure_multiple_of: int = 1, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = False, + size_divisor: Optional[int] = None, + do_reduce_labels: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 384, "width": 384} + size = get_size_dict(size) + self.do_resize = do_resize + self.size = size + self.keep_aspect_ratio = keep_aspect_ratio + self.ensure_multiple_of = ensure_multiple_of + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.do_pad = do_pad + self.size_divisor = size_divisor + self.do_reduce_labels = do_reduce_labels + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + keep_aspect_ratio: bool = False, + ensure_multiple_of: int = 1, + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to target size `(size["height"], size["width"])`. If `keep_aspect_ratio` is `True`, the image + is resized to the largest possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is + set, the image is resized to a size that is a multiple of this value. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Target size of the output image. + keep_aspect_ratio (`bool`, *optional*, defaults to `False`): + If `True`, the image is resized to the largest possible size such that the aspect ratio is preserved. + ensure_multiple_of (`int`, *optional*, defaults to 1): + The image is resized to a size that is a multiple of this value. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Defines the resampling filter to use if resizing the image. Otherwise, the image is resized to size + specified in `size`. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The size dictionary must contain the keys 'height' and 'width'. Got {size.keys()}") + + output_size = get_resize_output_image_size( + image, + output_size=(size["height"], size["width"]), + keep_aspect_ratio=keep_aspect_ratio, + multiple=ensure_multiple_of, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def pad_image( + self, + image: np.array, + size_divisor: int, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Center pad an image to be a multiple of `multiple`. + + Args: + image (`np.ndarray`): + Image to pad. + size_divisor (`int`): + The width and height of the image will be padded to a multiple of this number. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + def _get_pad(size, size_divisor): + new_size = math.ceil(size / size_divisor) * size_divisor + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + height, width = get_image_size(image, input_data_format) + + pad_size_left, pad_size_right = _get_pad(height, size_divisor) + pad_size_top, pad_size_bottom = _get_pad(width, size_divisor) + + return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format) + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.reduce_label + def reduce_label(self, label: ImageInput) -> np.ndarray: + label = to_numpy_array(label) + # Avoid using underflow conversion + label[label == 0] = 255 + label = label - 1 + label[label == 254] = 255 + return label + + def _preprocess( + self, + image: ImageInput, + do_reduce_labels: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisor: Optional[int] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + if do_reduce_labels: + image = self.reduce_label(image) + + if do_resize: + image = self.resize( + image=image, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + input_data_format=input_data_format, + ) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + if do_pad: + image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisor: Optional[int] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. + image = to_numpy_array(image) + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(image) + + image = self._preprocess( + image, + do_reduce_labels=False, + do_resize=do_resize, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisor=size_divisor, + input_data_format=input_data_format, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + def _preprocess_segmentation_map( + self, + segmentation_map: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + do_reduce_labels: Optional[bool] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """Preprocesses a single segmentation map.""" + # All transformations expect numpy arrays. + segmentation_map = to_numpy_array(segmentation_map) + # Add an axis to the segmentation maps for transformations. + if segmentation_map.ndim == 2: + segmentation_map = segmentation_map[None, ...] + added_dimension = True + input_data_format = ChannelDimension.FIRST + else: + added_dimension = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1) + segmentation_map = self._preprocess( + image=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + do_normalize=False, + do_rescale=False, + input_data_format=input_data_format, + ) + # Remove extra axis if added + if added_dimension: + segmentation_map = np.squeeze(segmentation_map, axis=0) + segmentation_map = segmentation_map.astype(np.int64) + return segmentation_map + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.__call__ + def __call__(self, images, segmentation_maps=None, **kwargs): + # Overrides the `__call__` method of the `Preprocessor` class such that the images and segmentation maps can both + # be passed in as positional arguments. + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[int] = None, + keep_aspect_ratio: Optional[bool] = None, + ensure_multiple_of: Optional[int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisor: Optional[int] = None, + do_reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after reszing. If `keep_aspect_ratio` is `True`, the image is resized to the largest + possible size such that the aspect ratio is preserved. If `ensure_multiple_of` is set, the image is + resized to a size that is a multiple of this value. + keep_aspect_ratio (`bool`, *optional*, defaults to `self.keep_aspect_ratio`): + Whether to keep the aspect ratio of the image. If False, the image will be resized to (size, size). If + True, the image will be resized to keep the aspect ratio and the size will be the maximum possible. + ensure_multiple_of (`int`, *optional*, defaults to `self.ensure_multiple_of`): + Ensure that the image size is a multiple of this value. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size) + keep_aspect_ratio = keep_aspect_ratio if keep_aspect_ratio is not None else self.keep_aspect_ratio + ensure_multiple_of = ensure_multiple_of if ensure_multiple_of is not None else self.ensure_multiple_of + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_pad = do_pad if do_pad is not None else self.do_pad + size_divisor = size_divisor if size_divisor is not None else self.size_divisor + do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels + + images = make_list_of_images(images) + + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisor, + do_resize=do_resize, + size=size, + resample=resample, + ) + + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + do_rescale=do_rescale, + do_normalize=do_normalize, + do_pad=do_pad, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + rescale_factor=rescale_factor, + image_mean=image_mean, + image_std=image_std, + size_divisor=size_divisor, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + + data = {"pixel_values": images} + + if segmentation_maps is not None: + segmentation_maps = [ + self._preprocess_segmentation_map( + segmentation_map=segmentation_map, + do_reduce_labels=do_reduce_labels, + do_resize=do_resize, + size=size, + resample=resample, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=ensure_multiple_of, + input_data_format=input_data_format, + ) + for segmentation_map in segmentation_maps + ] + + data["labels"] = segmentation_maps + + return BatchFeature(data=data, tensor_type=return_tensors) + + # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + """ + Converts the output of [`DPTForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: + outputs ([`DPTForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If unset, + predictions will not be resized. + + Returns: + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + """ + # TODO: add support for other frameworks + logits = outputs.logits + + # Resize logits and compute semantic segmentation maps + if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + semantic_segmentation = [] + + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + + return semantic_segmentation + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + ) -> List[Dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = torch.nn.functional.interpolate( + depth.unsqueeze(0).unsqueeze(1), size=target_size, mode="bicubic", align_corners=False + ).squeeze() + + results.append({"predicted_depth": depth}) + + return results + + +__all__ = ["DPTImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py b/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc1dcf2f87e072ed72b0892f0d88a9175203bd0 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/dpt/modeling_dpt.py @@ -0,0 +1,1413 @@ +# coding=utf-8 +# Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch DPT (Dense Prediction Transformers) model. + +This implementation is heavily inspired by OpenMMLab's implementation, found here: +https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py. + +""" + +import collections.abc +from dataclasses import dataclass +from typing import Callable, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...file_utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ModelOutput, logging, torch_int +from ...utils.backbone_utils import load_backbone +from .configuration_dpt import DPTConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "DPTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "Intel/dpt-large" +_EXPECTED_OUTPUT_SHAPE = [1, 577, 1024] + + +@dataclass +class BaseModelOutputWithIntermediateActivations(ModelOutput): + """ + Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful + in the context of Vision models.: + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + intermediate_activations (`tuple(torch.FloatTensor)`, *optional*): + Intermediate activations that can be used to compute hidden states of the model at various layers. + """ + + last_hidden_states: Optional[torch.FloatTensor] = None + intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate + activations that can be used by the model at later stages. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + intermediate_activations (`tuple(torch.FloatTensor)`, *optional*): + Intermediate activations that can be used to compute hidden states of the model at various layers. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + pooler_output: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class DPTViTHybridEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config, feature_size=None): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + + self.backbone = load_backbone(config) + feature_dim = self.backbone.channels[-1] + if len(self.backbone.channels) != 3: + raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}") + self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage + + if feature_size is None: + feat_map_shape = config.backbone_featmap_shape + feature_size = feat_map_shape[-2:] + feature_dim = feat_map_shape[1] + else: + feature_size = ( + feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size) + ) + feature_dim = self.backbone.channels[-1] + + self.image_size = image_size + self.patch_size = patch_size[0] + self.num_channels = num_channels + + self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + + def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1): + posemb_tok = posemb[:, :start_index] + posemb_grid = posemb[0, start_index:] + + old_grid_size = torch_int(len(posemb_grid) ** 0.5) + + posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) + posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + def forward( + self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + + position_embeddings = self._resize_pos_embed( + self.position_embeddings, height // self.patch_size, width // self.patch_size + ) + + backbone_output = self.backbone(pixel_values) + + features = backbone_output.feature_maps[-1] + + # Retrieve also the intermediate activations to use them at later stages + output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index] + + embeddings = self.projection(features).flatten(2).transpose(1, 2) + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + position_embeddings + + if not return_dict: + return (embeddings, output_hidden_states) + + # Return hidden states and intermediate activations + return BaseModelOutputWithIntermediateActivations( + last_hidden_states=embeddings, + intermediate_activations=output_hidden_states, + ) + + +class DPTViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + + """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = DPTViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1): + posemb_tok = posemb[:, :start_index] + posemb_grid = posemb[0, start_index:] + + old_grid_size = torch_int(posemb_grid.size(0) ** 0.5) + + posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2) + posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + def forward(self, pixel_values, return_dict=False): + batch_size, num_channels, height, width = pixel_values.shape + + # possibly interpolate position encodings to handle varying image sizes + patch_size = self.config.patch_size + position_embeddings = self._resize_pos_embed( + self.position_embeddings, height // patch_size, width // patch_size + ) + + embeddings = self.patch_embeddings(pixel_values) + + batch_size, seq_len, _ = embeddings.size() + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + embeddings = embeddings + position_embeddings + + embeddings = self.dropout(embeddings) + + if not return_dict: + return (embeddings,) + + return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings) + + +class DPTViTPatchEmbeddings(nn.Module): + """ + Image to Patch Embedding. + + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT +class DPTSelfAttention(nn.Module): + def __init__(self, config: DPTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DPT +class DPTViTSelfOutput(nn.Module): + """ + The residual connection is defined in DPTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class DPTViTAttention(nn.Module): + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.attention = DPTSelfAttention(config) + self.output = DPTViTSelfOutput(config) + self.pruned_heads = set() + + # Copied from transformers.models.vit.modeling_vit.ViTAttention.prune_heads + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + # Copied from transformers.models.vit.modeling_vit.ViTAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DPT +class DPTViTIntermediate(nn.Module): + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DPT +class DPTViTOutput(nn.Module): + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput +class DPTViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = DPTViTAttention(config) + self.intermediate = DPTViTIntermediate(config) + self.output = DPTViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# copied from transformers.models.vit.modeling_vit.ViTEncoder with ViTConfig -> DPTConfig, ViTLayer->DPTViTLayer +class DPTViTEncoder(nn.Module): + def __init__(self, config: DPTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class DPTReassembleStage(nn.Module): + """ + This class reassembles the hidden states of the backbone into image-like feature representations at various + resolutions. + + This happens in 3 stages: + 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to + `config.readout_type`. + 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`. + 3. Resizing the spatial dimensions (height, width). + + Args: + config (`[DPTConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.layers = nn.ModuleList() + if config.is_hybrid: + self._init_reassemble_dpt_hybrid(config) + else: + self._init_reassemble_dpt(config) + + self.neck_ignore_stages = config.neck_ignore_stages + + def _init_reassemble_dpt_hybrid(self, config): + r""" " + For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official + implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438 + for more details. + """ + for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors): + if i <= 1: + self.layers.append(nn.Identity()) + elif i > 1: + self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor)) + + if config.readout_type != "project": + raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.") + + # When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file + self.readout_projects = nn.ModuleList() + hidden_size = _get_backbone_hidden_size(config) + for i in range(len(config.neck_hidden_sizes)): + if i <= 1: + self.readout_projects.append(nn.Sequential(nn.Identity())) + elif i > 1: + self.readout_projects.append( + nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act]) + ) + + def _init_reassemble_dpt(self, config): + for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors): + self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor)) + + if config.readout_type == "project": + self.readout_projects = nn.ModuleList() + hidden_size = _get_backbone_hidden_size(config) + for _ in range(len(config.neck_hidden_sizes)): + self.readout_projects.append( + nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act]) + ) + + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`): + List of hidden states from the backbone. + """ + out = [] + + for i, hidden_state in enumerate(hidden_states): + if i not in self.neck_ignore_stages: + # reshape to (batch_size, num_channels, height, width) + cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:] + batch_size, sequence_length, num_channels = hidden_state.shape + if patch_height is not None and patch_width is not None: + hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels) + else: + size = torch_int(sequence_length**0.5) + hidden_state = hidden_state.reshape(batch_size, size, size, num_channels) + hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous() + + feature_shape = hidden_state.shape + if self.config.readout_type == "project": + # reshape to (batch_size, height*width, num_channels) + hidden_state = hidden_state.flatten(2).permute((0, 2, 1)) + readout = cls_token.unsqueeze(1).expand_as(hidden_state) + # concatenate the readout token to the hidden states and project + hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1)) + # reshape back to (batch_size, num_channels, height, width) + hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape) + elif self.config.readout_type == "add": + hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1) + hidden_state = hidden_state.reshape(feature_shape) + hidden_state = self.layers[i](hidden_state) + out.append(hidden_state) + + return out + + +def _get_backbone_hidden_size(config): + if config.backbone_config is not None and config.is_hybrid is False: + return config.backbone_config.hidden_size + else: + return config.hidden_size + + +class DPTReassembleLayer(nn.Module): + def __init__(self, config, channels, factor): + super().__init__() + # projection + hidden_size = _get_backbone_hidden_size(config) + self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1) + + # up/down sampling depending on factor + if factor > 1: + self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0) + elif factor == 1: + self.resize = nn.Identity() + elif factor < 1: + # so should downsample + self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1) + + def forward(self, hidden_state): + hidden_state = self.projection(hidden_state) + hidden_state = self.resize(hidden_state) + return hidden_state + + +class DPTFeatureFusionStage(nn.Module): + def __init__(self, config): + super().__init__() + self.layers = nn.ModuleList() + for _ in range(len(config.neck_hidden_sizes)): + self.layers.append(DPTFeatureFusionLayer(config)) + + def forward(self, hidden_states): + # reversing the hidden_states, we start from the last + hidden_states = hidden_states[::-1] + + fused_hidden_states = [] + fused_hidden_state = None + for hidden_state, layer in zip(hidden_states, self.layers): + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state) + fused_hidden_states.append(fused_hidden_state) + + return fused_hidden_states + + +class DPTPreActResidualLayer(nn.Module): + """ + ResidualConvUnit, pre-activate residual unit. + + Args: + config (`[DPTConfig]`): + Model configuration class defining the model architecture. + """ + + def __init__(self, config): + super().__init__() + + self.use_batch_norm = config.use_batch_norm_in_fusion_residual + use_bias_in_fusion_residual = ( + config.use_bias_in_fusion_residual + if config.use_bias_in_fusion_residual is not None + else not self.use_batch_norm + ) + + self.activation1 = nn.ReLU() + self.convolution1 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias_in_fusion_residual, + ) + + self.activation2 = nn.ReLU() + self.convolution2 = nn.Conv2d( + config.fusion_hidden_size, + config.fusion_hidden_size, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias_in_fusion_residual, + ) + + if self.use_batch_norm: + self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size) + self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + residual = hidden_state + hidden_state = self.activation1(hidden_state) + + hidden_state = self.convolution1(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm1(hidden_state) + + hidden_state = self.activation2(hidden_state) + hidden_state = self.convolution2(hidden_state) + + if self.use_batch_norm: + hidden_state = self.batch_norm2(hidden_state) + + return hidden_state + residual + + +class DPTFeatureFusionLayer(nn.Module): + """Feature fusion layer, merges feature maps from different stages. + + Args: + config (`[DPTConfig]`): + Model configuration class defining the model architecture. + align_corners (`bool`, *optional*, defaults to `True`): + The align_corner setting for bilinear upsample. + """ + + def __init__(self, config, align_corners=True): + super().__init__() + + self.align_corners = align_corners + + self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True) + + self.residual_layer1 = DPTPreActResidualLayer(config) + self.residual_layer2 = DPTPreActResidualLayer(config) + + def forward(self, hidden_state, residual=None): + if residual is not None: + if hidden_state.shape != residual.shape: + residual = nn.functional.interpolate( + residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False + ) + hidden_state = hidden_state + self.residual_layer1(residual) + + hidden_state = self.residual_layer2(hidden_state) + hidden_state = nn.functional.interpolate( + hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + hidden_state = self.projection(hidden_state) + + return hidden_state + + +class DPTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DPTConfig + base_model_prefix = "dpt" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + + +DPT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +DPT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DPT Model transformer outputting raw hidden-states without any specific head on top.", + DPT_START_DOCSTRING, +) +class DPTModel(DPTPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + # vit encoder + if config.is_hybrid: + self.embeddings = DPTViTHybridEmbeddings(config) + else: + self.embeddings = DPTViTEmbeddings(config) + self.encoder = DPTViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = DPTViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + if self.config.is_hybrid: + return self.embeddings + else: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndIntermediateActivations, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: torch.FloatTensor, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndIntermediateActivations]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values, return_dict=return_dict) + + embedding_last_hidden_states = embedding_output[0] if not return_dict else embedding_output.last_hidden_states + + encoder_outputs = self.encoder( + embedding_last_hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + embedding_output[1:] + + return BaseModelOutputWithPoolingAndIntermediateActivations( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + intermediate_activations=embedding_output.intermediate_activations, + ) + + +# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DPT +class DPTViTPooler(nn.Module): + def __init__(self, config: DPTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.pooler_output_size) + self.activation = ACT2FN[config.pooler_act] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class DPTNeck(nn.Module): + """ + DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as + input and produces another list of tensors as output. For DPT, it includes 2 stages: + + * DPTReassembleStage + * DPTFeatureFusionStage. + + Args: + config (dict): config dict. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT) + if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]: + self.reassemble_stage = None + else: + self.reassemble_stage = DPTReassembleStage(config) + + self.convs = nn.ModuleList() + for channel in config.neck_hidden_sizes: + self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False)) + + # fusion + self.fusion_stage = DPTFeatureFusionStage(config) + + def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]: + """ + Args: + hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`): + List of hidden states from the backbone. + """ + if not isinstance(hidden_states, (tuple, list)): + raise TypeError("hidden_states should be a tuple or list of tensors") + + if len(hidden_states) != len(self.config.neck_hidden_sizes): + raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.") + + # postprocess hidden states + if self.reassemble_stage is not None: + hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width) + + features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)] + + # fusion blocks + output = self.fusion_stage(features) + + return output + + +class DPTDepthEstimationHead(nn.Module): + """ + Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples + the predictions to the input resolution after the first convolutional layer (details can be found in the paper's + supplementary material). + """ + + def __init__(self, config): + super().__init__() + + self.config = config + + self.projection = None + if config.add_projection: + self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + + features = config.fusion_hidden_size + self.head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(), + ) + + def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: + # use last features + hidden_states = hidden_states[self.config.head_in_index] + + if self.projection is not None: + hidden_states = self.projection(hidden_states) + hidden_states = nn.ReLU()(hidden_states) + + predicted_depth = self.head(hidden_states) + + predicted_depth = predicted_depth.squeeze(dim=1) + + return predicted_depth + + +@add_start_docstrings( + """ + DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2. + """, + DPT_START_DOCSTRING, +) +class DPTForDepthEstimation(DPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.backbone = None + if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None): + self.backbone = load_backbone(config) + else: + self.dpt = DPTModel(config, add_pooling_layer=False) + + # Neck + self.neck = DPTNeck(config) + + # Depth estimation head + self.head = DPTDepthEstimationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.FloatTensor, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth depth estimation maps for computing the loss. + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, DPTForDepthEstimation + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large") + >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large") + + >>> # prepare image for the model + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> # interpolate to original size + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], + ... ) + + >>> # visualize the prediction + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 255 / predicted_depth.max() + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint8")) + ```""" + loss = None + if labels is not None: + raise NotImplementedError("Training is not implemented yet") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if self.backbone is not None: + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + hidden_states = outputs.feature_maps + else: + outputs = self.dpt( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + hidden_states = outputs.hidden_states if return_dict else outputs[1] + # only keep certain features based on config.backbone_out_indices + # note that the hidden_states also include the initial embeddings + if not self.config.is_hybrid: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices + ] + else: + backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1]) + backbone_hidden_states.extend( + feature + for idx, feature in enumerate(hidden_states[1:]) + if idx in self.config.backbone_out_indices[2:] + ) + + hidden_states = backbone_hidden_states + + patch_height, patch_width = None, None + if self.config.backbone_config is not None and self.config.is_hybrid is False: + _, _, height, width = pixel_values.shape + patch_size = self.config.backbone_config.patch_size + patch_height = height // patch_size + patch_width = width // patch_size + + hidden_states = self.neck(hidden_states, patch_height, patch_width) + + predicted_depth = self.head(hidden_states) + + if not return_dict: + if output_hidden_states: + output = (predicted_depth,) + outputs[1:] + else: + output = (predicted_depth,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return DepthEstimatorOutput( + loss=loss, + predicted_depth=predicted_depth, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +class DPTSemanticSegmentationHead(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + + features = config.fusion_hidden_size + self.head = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(), + nn.Dropout(config.semantic_classifier_dropout), + nn.Conv2d(features, config.num_labels, kernel_size=1), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + ) + + def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor: + # use last features + hidden_states = hidden_states[self.config.head_in_index] + + logits = self.head(hidden_states) + + return logits + + +class DPTAuxiliaryHead(nn.Module): + def __init__(self, config): + super().__init__() + + features = config.fusion_hidden_size + self.head = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(), + nn.Dropout(0.1, False), + nn.Conv2d(features, config.num_labels, kernel_size=1), + ) + + def forward(self, hidden_states): + logits = self.head(hidden_states) + + return logits + + +@add_start_docstrings( + """ + DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes. + """, + DPT_START_DOCSTRING, +) +class DPTForSemanticSegmentation(DPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.dpt = DPTModel(config, add_pooling_layer=False) + + # Neck + self.neck = DPTNeck(config) + + # Segmentation head(s) + self.head = DPTSemanticSegmentationHead(config) + self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SemanticSegmenterOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): + Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade") + >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade") + + >>> inputs = image_processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if labels is not None and self.config.num_labels == 1: + raise ValueError("The number of labels should be greater than one") + + outputs = self.dpt( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, # we need the intermediate hidden states + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + # only keep certain features based on config.backbone_out_indices + # note that the hidden_states also include the initial embeddings + if not self.config.is_hybrid: + hidden_states = [ + feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices + ] + else: + backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1]) + backbone_hidden_states.extend( + feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:] + ) + + hidden_states = backbone_hidden_states + + hidden_states = self.neck(hidden_states=hidden_states) + + logits = self.head(hidden_states) + + auxiliary_logits = None + if self.auxiliary_head is not None: + auxiliary_logits = self.auxiliary_head(hidden_states[-1]) + + loss = None + if labels is not None: + # upsample logits to the images' original size + upsampled_logits = nn.functional.interpolate( + logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + if auxiliary_logits is not None: + upsampled_auxiliary_logits = nn.functional.interpolate( + auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False + ) + # compute weighted loss + loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index) + main_loss = loss_fct(upsampled_logits, labels) + auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels) + loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss + + if not return_dict: + if output_hidden_states: + output = (logits,) + outputs[1:] + else: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SemanticSegmenterOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["DPTForDepthEstimation", "DPTForSemanticSegmentation", "DPTModel", "DPTPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py b/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24d58e81167ec8729d04fa52ce96ebc1737a5982 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_efficientnet import * + from .image_processing_efficientnet import * + from .image_processing_efficientnet_fast import * + from .modeling_efficientnet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..29df2ce0e34a8291ee484b3f93a96081f658d1db --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/configuration_efficientnet.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EfficientNet model configuration""" + +from collections import OrderedDict +from typing import List, Mapping + +from packaging import version + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class EfficientNetConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`EfficientNetModel`]. It is used to instantiate an + EfficientNet model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the EfficientNet + [google/efficientnet-b7](https://huggingface.co/google/efficientnet-b7) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + image_size (`int`, *optional*, defaults to 600): + The input image size. + width_coefficient (`float`, *optional*, defaults to 2.0): + Scaling coefficient for network width at each stage. + depth_coefficient (`float`, *optional*, defaults to 3.1): + Scaling coefficient for network depth at each stage. + depth_divisor `int`, *optional*, defaults to 8): + A unit of network width. + kernel_sizes (`List[int]`, *optional*, defaults to `[3, 3, 5, 3, 5, 5, 3]`): + List of kernel sizes to be used in each block. + in_channels (`List[int]`, *optional*, defaults to `[32, 16, 24, 40, 80, 112, 192]`): + List of input channel sizes to be used in each block for convolutional layers. + out_channels (`List[int]`, *optional*, defaults to `[16, 24, 40, 80, 112, 192, 320]`): + List of output channel sizes to be used in each block for convolutional layers. + depthwise_padding (`List[int]`, *optional*, defaults to `[]`): + List of block indices with square padding. + strides (`List[int]`, *optional*, defaults to `[1, 2, 2, 2, 1, 2, 1]`): + List of stride sizes to be used in each block for convolutional layers. + num_block_repeats (`List[int]`, *optional*, defaults to `[1, 2, 2, 3, 3, 4, 1]`): + List of the number of times each block is to repeated. + expand_ratios (`List[int]`, *optional*, defaults to `[1, 6, 6, 6, 6, 6, 6]`): + List of scaling coefficient of each block. + squeeze_expansion_ratio (`float`, *optional*, defaults to 0.25): + Squeeze expansion ratio. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`, + `"selu", `"gelu_new"`, `"silu"` and `"mish"` are supported. + hidden_dim (`int`, *optional*, defaults to 1280): + The hidden dimension of the layer before the classification head. + pooling_type (`str` or `function`, *optional*, defaults to `"mean"`): + Type of final pooling to be applied before the dense classification head. Available options are [`"mean"`, + `"max"`] + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + batch_norm_eps (`float`, *optional*, defaults to 1e-3): + The epsilon used by the batch normalization layers. + batch_norm_momentum (`float`, *optional*, defaults to 0.99): + The momentum used by the batch normalization layers. + dropout_rate (`float`, *optional*, defaults to 0.5): + The dropout rate to be applied before final classifier layer. + drop_connect_rate (`float`, *optional*, defaults to 0.2): + The drop rate for skip connections. + + Example: + ```python + >>> from transformers import EfficientNetConfig, EfficientNetModel + + >>> # Initializing a EfficientNet efficientnet-b7 style configuration + >>> configuration = EfficientNetConfig() + + >>> # Initializing a model (with random weights) from the efficientnet-b7 style configuration + >>> model = EfficientNetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "efficientnet" + + def __init__( + self, + num_channels: int = 3, + image_size: int = 600, + width_coefficient: float = 2.0, + depth_coefficient: float = 3.1, + depth_divisor: int = 8, + kernel_sizes: List[int] = [3, 3, 5, 3, 5, 5, 3], + in_channels: List[int] = [32, 16, 24, 40, 80, 112, 192], + out_channels: List[int] = [16, 24, 40, 80, 112, 192, 320], + depthwise_padding: List[int] = [], + strides: List[int] = [1, 2, 2, 2, 1, 2, 1], + num_block_repeats: List[int] = [1, 2, 2, 3, 3, 4, 1], + expand_ratios: List[int] = [1, 6, 6, 6, 6, 6, 6], + squeeze_expansion_ratio: float = 0.25, + hidden_act: str = "swish", + hidden_dim: int = 2560, + pooling_type: str = "mean", + initializer_range: float = 0.02, + batch_norm_eps: float = 0.001, + batch_norm_momentum: float = 0.99, + dropout_rate: float = 0.5, + drop_connect_rate: float = 0.2, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_channels = num_channels + self.image_size = image_size + self.width_coefficient = width_coefficient + self.depth_coefficient = depth_coefficient + self.depth_divisor = depth_divisor + self.kernel_sizes = kernel_sizes + self.in_channels = in_channels + self.out_channels = out_channels + self.depthwise_padding = depthwise_padding + self.strides = strides + self.num_block_repeats = num_block_repeats + self.expand_ratios = expand_ratios + self.squeeze_expansion_ratio = squeeze_expansion_ratio + self.hidden_act = hidden_act + self.hidden_dim = hidden_dim + self.pooling_type = pooling_type + self.initializer_range = initializer_range + self.batch_norm_eps = batch_norm_eps + self.batch_norm_momentum = batch_norm_momentum + self.dropout_rate = dropout_rate + self.drop_connect_rate = drop_connect_rate + self.num_hidden_layers = sum(num_block_repeats) * 4 + + +class EfficientNetOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-5 + + +__all__ = ["EfficientNetConfig", "EfficientNetOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py b/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..e9988524aca04de2a1d600586ff01d9b9a3ea6c2 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py @@ -0,0 +1,339 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert EfficientNet checkpoints from the original repository. + +URL: https://github.com/keras-team/keras/blob/v2.11.0/keras/applications/efficientnet.py""" + +import argparse +import json +import os + +import numpy as np +import PIL +import requests +import tensorflow.keras.applications.efficientnet as efficientnet +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from tensorflow.keras.preprocessing import image + +from transformers import ( + EfficientNetConfig, + EfficientNetForImageClassification, + EfficientNetImageProcessor, +) +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +model_classes = { + "b0": efficientnet.EfficientNetB0, + "b1": efficientnet.EfficientNetB1, + "b2": efficientnet.EfficientNetB2, + "b3": efficientnet.EfficientNetB3, + "b4": efficientnet.EfficientNetB4, + "b5": efficientnet.EfficientNetB5, + "b6": efficientnet.EfficientNetB6, + "b7": efficientnet.EfficientNetB7, +} + +CONFIG_MAP = { + "b0": { + "hidden_dim": 1280, + "width_coef": 1.0, + "depth_coef": 1.0, + "image_size": 224, + "dropout_rate": 0.2, + "dw_padding": [], + }, + "b1": { + "hidden_dim": 1280, + "width_coef": 1.0, + "depth_coef": 1.1, + "image_size": 240, + "dropout_rate": 0.2, + "dw_padding": [16], + }, + "b2": { + "hidden_dim": 1408, + "width_coef": 1.1, + "depth_coef": 1.2, + "image_size": 260, + "dropout_rate": 0.3, + "dw_padding": [5, 8, 16], + }, + "b3": { + "hidden_dim": 1536, + "width_coef": 1.2, + "depth_coef": 1.4, + "image_size": 300, + "dropout_rate": 0.3, + "dw_padding": [5, 18], + }, + "b4": { + "hidden_dim": 1792, + "width_coef": 1.4, + "depth_coef": 1.8, + "image_size": 380, + "dropout_rate": 0.4, + "dw_padding": [6], + }, + "b5": { + "hidden_dim": 2048, + "width_coef": 1.6, + "depth_coef": 2.2, + "image_size": 456, + "dropout_rate": 0.4, + "dw_padding": [13, 27], + }, + "b6": { + "hidden_dim": 2304, + "width_coef": 1.8, + "depth_coef": 2.6, + "image_size": 528, + "dropout_rate": 0.5, + "dw_padding": [31], + }, + "b7": { + "hidden_dim": 2560, + "width_coef": 2.0, + "depth_coef": 3.1, + "image_size": 600, + "dropout_rate": 0.5, + "dw_padding": [18], + }, +} + + +def get_efficientnet_config(model_name): + config = EfficientNetConfig() + config.hidden_dim = CONFIG_MAP[model_name]["hidden_dim"] + config.width_coefficient = CONFIG_MAP[model_name]["width_coef"] + config.depth_coefficient = CONFIG_MAP[model_name]["depth_coef"] + config.image_size = CONFIG_MAP[model_name]["image_size"] + config.dropout_rate = CONFIG_MAP[model_name]["dropout_rate"] + config.depthwise_padding = CONFIG_MAP[model_name]["dw_padding"] + + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + config.num_labels = 1000 + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + return config + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +def convert_image_processor(model_name): + size = CONFIG_MAP[model_name]["image_size"] + preprocessor = EfficientNetImageProcessor( + size={"height": size, "width": size}, + image_mean=[0.485, 0.456, 0.406], + image_std=[0.47853944, 0.4732864, 0.47434163], + do_center_crop=False, + ) + return preprocessor + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def rename_keys(original_param_names): + block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")] + block_names = sorted(set(block_names)) + num_blocks = len(block_names) + block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))} + + rename_keys = [] + rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight")) + rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight")) + rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias")) + rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean")) + rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var")) + + for b in block_names: + hf_b = block_name_mapping[b] + rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight")) + rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight")) + rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias")) + rename_keys.append( + (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean") + ) + rename_keys.append( + (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var") + ) + rename_keys.append( + (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight") + ) + rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight")) + rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias")) + rename_keys.append( + (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean") + ) + rename_keys.append( + (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var") + ) + + rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight")) + rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias")) + rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight")) + rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias")) + rename_keys.append( + (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight") + ) + rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight")) + rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias")) + rename_keys.append( + (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean") + ) + rename_keys.append( + (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var") + ) + + rename_keys.append(("top_conv/kernel:0", "encoder.top_conv.weight")) + rename_keys.append(("top_bn/gamma:0", "encoder.top_bn.weight")) + rename_keys.append(("top_bn/beta:0", "encoder.top_bn.bias")) + rename_keys.append(("top_bn/moving_mean:0", "encoder.top_bn.running_mean")) + rename_keys.append(("top_bn/moving_variance:0", "encoder.top_bn.running_var")) + + key_mapping = {} + for item in rename_keys: + if item[0] in original_param_names: + key_mapping[item[0]] = "efficientnet." + item[1] + + key_mapping["predictions/kernel:0"] = "classifier.weight" + key_mapping["predictions/bias:0"] = "classifier.bias" + return key_mapping + + +def replace_params(hf_params, tf_params, key_mapping): + for key, value in tf_params.items(): + if "normalization" in key: + continue + + hf_key = key_mapping[key] + if "_conv" in key and "kernel" in key: + new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1) + elif "depthwise_kernel" in key: + new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1) + elif "kernel" in key: + new_hf_value = torch.from_numpy(np.transpose(value)) + else: + new_hf_value = torch.from_numpy(value) + + # Replace HF parameters with original TF model parameters + assert hf_params[hf_key].shape == new_hf_value.shape + hf_params[hf_key].copy_(new_hf_value) + + +@torch.no_grad() +def convert_efficientnet_checkpoint(model_name, pytorch_dump_folder_path, save_model, push_to_hub): + """ + Copy/paste/tweak model's weights to our EfficientNet structure. + """ + # Load original model + original_model = model_classes[model_name]( + include_top=True, + weights="imagenet", + input_tensor=None, + input_shape=None, + pooling=None, + classes=1000, + classifier_activation="softmax", + ) + + tf_params = original_model.trainable_variables + tf_non_train_params = original_model.non_trainable_variables + tf_params = {param.name: param.numpy() for param in tf_params} + for param in tf_non_train_params: + tf_params[param.name] = param.numpy() + tf_param_names = list(tf_params.keys()) + + # Load HuggingFace model + config = get_efficientnet_config(model_name) + hf_model = EfficientNetForImageClassification(config).eval() + hf_params = hf_model.state_dict() + + # Create src-to-dst parameter name mapping dictionary + print("Converting parameters...") + key_mapping = rename_keys(tf_param_names) + replace_params(hf_params, tf_params, key_mapping) + + # Initialize preprocessor and preprocess input image + preprocessor = convert_image_processor(model_name) + inputs = preprocessor(images=prepare_img(), return_tensors="pt") + + # HF model inference + hf_model.eval() + with torch.no_grad(): + outputs = hf_model(**inputs) + hf_logits = outputs.logits.detach().numpy() + + # Original model inference + original_model.trainable = False + image_size = CONFIG_MAP[model_name]["image_size"] + img = prepare_img().resize((image_size, image_size), resample=PIL.Image.NEAREST) + x = image.img_to_array(img) + x = np.expand_dims(x, axis=0) + original_logits = original_model.predict(x) + + # Check whether original and HF model outputs match -> np.allclose + assert np.allclose(original_logits, hf_logits, atol=1e-3), "The predicted logits are not the same." + print("Model outputs match!") + + if save_model: + # Create folder to save model + if not os.path.isdir(pytorch_dump_folder_path): + os.mkdir(pytorch_dump_folder_path) + # Save converted model and image processor + hf_model.save_pretrained(pytorch_dump_folder_path) + preprocessor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + # Push model and image processor to hub + print(f"Pushing converted {model_name} to the hub...") + model_name = f"efficientnet-{model_name}" + preprocessor.push_to_hub(model_name) + hf_model.push_to_hub(model_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="b0", + type=str, + help="Version name of the EfficientNet model you want to convert, select from [b0, b1, b2, b3, b4, b5, b6, b7].", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + default="hf_model", + type=str, + help="Path to the output PyTorch model directory.", + ) + parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument("--push_to_hub", action="store_true", help="Push model and image processor to the hub") + + args = parser.parse_args() + convert_efficientnet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..612ede7086ead59ff28a051faf957a9120fdfa61 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet.py @@ -0,0 +1,369 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for EfficientNet.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import rescale, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +class EfficientNetImageProcessor(BaseImageProcessor): + r""" + Constructs a EfficientNet image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in `preprocess`. + size (`Dict[str, int]` *optional*, defaults to `{"height": 346, "width": 346}`): + Size of the image after `resize`. Can be overridden by `size` in `preprocess`. + resample (`PILImageResampling` filter, *optional*, defaults to 0): + Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`. + do_center_crop (`bool`, *optional*, defaults to `False`): + Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image + is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`. + crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 289, "width": 289}`): + Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + rescale_offset (`bool`, *optional*, defaults to `False`): + Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + include_top (`bool`, *optional*, defaults to `True`): + Whether to rescale the image again. Should be set to True if the inputs are used for image classification. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PIL.Image.NEAREST, + do_center_crop: bool = False, + crop_size: Dict[str, int] = None, + rescale_factor: Union[int, float] = 1 / 255, + rescale_offset: bool = False, + do_rescale: bool = True, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + include_top: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 346, "width": 346} + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else {"height": 289, "width": 289} + crop_size = get_size_dict(crop_size, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.rescale_offset = rescale_offset + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.include_top = include_top + + # Copied from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.NEAREST + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.NEAREST, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.NEAREST`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.NEAREST`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def rescale( + self, + image: np.ndarray, + scale: Union[int, float], + offset: bool = True, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`int` or `float`): + Scale to apply to the image. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + rescaled_image = rescale( + image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + if offset: + rescaled_image = rescaled_image - 1 + + return rescaled_image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample=None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + rescale_offset: Optional[bool] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + include_top: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after `resize`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + PILImageResampling filter to use if resizing the image Only has an effect if `do_resize` is set to + `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the image after center crop. If one edge the image is smaller than `crop_size`, it will be + padded with zeros and then cropped + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`): + Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + include_top (`bool`, *optional*, defaults to `self.include_top`): + Rescales the image again for image classification if set to True. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - `None`: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + rescale_offset = rescale_offset if rescale_offset is not None else self.rescale_offset + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + include_top = include_top if include_top is not None else self.include_top + + size = size if size is not None else self.size + size = get_size_dict(size) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_rescale: + images = [ + self.rescale( + image=image, scale=rescale_factor, offset=rescale_offset, input_data_format=input_data_format + ) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + if include_top: + images = [ + self.normalize(image=image, mean=0, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["EfficientNetImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..fb639564014fa7eaa2c3114bce6a23e7e76ecf46 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/image_processing_efficientnet_fast.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for EfficientNet.""" + +from functools import lru_cache +from typing import Optional, Union + +from ...image_processing_utils_fast import ( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + BaseImageProcessorFast, + BatchFeature, + DefaultFastImageProcessorKwargs, +) +from ...image_transforms import group_images_by_shape, reorder_images +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ImageInput, + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + add_start_docstrings, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + + +class EfficientNetFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + rescale_offset: bool + include_top: bool + + +@add_start_docstrings( + "Constructs a fast EfficientNet image processor.", + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING, +) +class EfficientNetImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.NEAREST + image_mean = IMAGENET_STANDARD_MEAN + image_std = IMAGENET_STANDARD_STD + size = {"height": 346, "width": 346} + crop_size = {"height": 289, "width": 289} + do_resize = True + do_center_crop = False + do_rescale = True + rescale_factor = 1 / 255 + rescale_offset = False + do_normalize = True + include_top = True + valid_kwargs = EfficientNetFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def rescale( + self, + image: "torch.Tensor", + scale: float, + offset: Optional[bool] = True, + **kwargs, + ) -> "torch.Tensor": + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`torch.Tensor`): + Image to rescale. + scale (`float`): + The scaling factor to rescale pixel values by. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + + Returns: + `torch.Tensor`: The rescaled image. + """ + + rescaled_image = image * scale + + if offset: + rescaled_image -= 1 + + return rescaled_image + + @lru_cache(maxsize=10) + def _fuse_mean_std_and_rescale_factor( + self, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + device: Optional["torch.device"] = None, + rescale_offset: Optional[bool] = False, + ) -> tuple: + if do_rescale and do_normalize and not rescale_offset: + # Fused rescale and normalize + image_mean = torch.tensor(image_mean, device=device) * (1.0 / rescale_factor) + image_std = torch.tensor(image_std, device=device) * (1.0 / rescale_factor) + do_rescale = False + return image_mean, image_std, do_rescale + + def rescale_and_normalize( + self, + images: "torch.Tensor", + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Union[float, list[float]], + image_std: Union[float, list[float]], + rescale_offset: bool = False, + ) -> "torch.Tensor": + """ + Rescale and normalize images. + """ + image_mean, image_std, do_rescale = self._fuse_mean_std_and_rescale_factor( + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + device=images.device, + rescale_offset=rescale_offset, + ) + # if/elif as we use fused rescale and normalize if both are set to True + if do_rescale: + images = self.rescale(images, rescale_factor, rescale_offset) + if do_normalize: + images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std) + + return images + + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_center_crop: bool, + crop_size: SizeDict, + do_rescale: bool, + rescale_factor: float, + rescale_offset: bool, + do_normalize: bool, + include_top: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> BatchFeature: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_center_crop: + stacked_images = self.center_crop(stacked_images, crop_size) + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std, rescale_offset + ) + if include_top: + stacked_images = self.normalize(stacked_images, 0, image_std) + processed_images_grouped[shape] = stacked_images + + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) + + @add_start_docstrings( + BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS, + """ + rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`): + Whether to rescale the image between [-max_range/2, scale_range/2] instead of [0, scale_range]. + include_top (`bool`, *optional*, defaults to `self.include_top`): + Normalize the image again with the standard deviation only for image classification if set to True. + """, + ) + def preprocess(self, images: ImageInput, **kwargs: Unpack[EfficientNetFastImageProcessorKwargs]) -> BatchFeature: + return super().preprocess(images, **kwargs) + + +__all__ = ["EfficientNetImageProcessorFast"] diff --git a/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py b/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..9e0b89072921b47bd42a5febf9cdd576a31e33e1 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/efficientnet/modeling_efficientnet.py @@ -0,0 +1,647 @@ +# coding=utf-8 +# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch EfficientNet model.""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutputWithNoAttention, + BaseModelOutputWithPoolingAndNoAttention, + ImageClassifierOutputWithNoAttention, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from .configuration_efficientnet import EfficientNetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "EfficientNetConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "google/efficientnet-b7" +_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/efficientnet-b7" +_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat" + + +EFFICIENTNET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`EfficientNetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +EFFICIENTNET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`AutoImageProcessor.__call__`] for details. + + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def round_filters(config: EfficientNetConfig, num_channels: int): + r""" + Round number of filters based on depth multiplier. + """ + divisor = config.depth_divisor + num_channels *= config.width_coefficient + new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor) + + # Make sure that round down does not go down by more than 10%. + if new_dim < 0.9 * num_channels: + new_dim += divisor + + return int(new_dim) + + +def correct_pad(kernel_size: Union[int, Tuple], adjust: bool = True): + r""" + Utility function to get the tuple padding value for the depthwise convolution. + + Args: + kernel_size (`int` or `tuple`): + Kernel size of the convolution layers. + adjust (`bool`, *optional*, defaults to `True`): + Adjusts padding value to apply to right and bottom sides of the input. + """ + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + correct = (kernel_size[0] // 2, kernel_size[1] // 2) + if adjust: + return (correct[1] - 1, correct[1], correct[0] - 1, correct[0]) + else: + return (correct[1], correct[1], correct[0], correct[0]) + + +class EfficientNetEmbeddings(nn.Module): + r""" + A module that corresponds to the stem module of the original work. + """ + + def __init__(self, config: EfficientNetConfig): + super().__init__() + + self.out_dim = round_filters(config, 32) + self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1)) + self.convolution = nn.Conv2d( + config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False + ) + self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum) + self.activation = ACT2FN[config.hidden_act] + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + features = self.padding(pixel_values) + features = self.convolution(features) + features = self.batchnorm(features) + features = self.activation(features) + + return features + + +class EfficientNetDepthwiseConv2d(nn.Conv2d): + def __init__( + self, + in_channels, + depth_multiplier=1, + kernel_size=3, + stride=1, + padding=0, + dilation=1, + bias=True, + padding_mode="zeros", + ): + out_channels = in_channels * depth_multiplier + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=bias, + padding_mode=padding_mode, + ) + + +class EfficientNetExpansionLayer(nn.Module): + r""" + This corresponds to the expansion phase of each block in the original implementation. + """ + + def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int): + super().__init__() + self.expand_conv = nn.Conv2d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + padding="same", + bias=False, + ) + self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps) + self.expand_act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + # Expand phase + hidden_states = self.expand_conv(hidden_states) + hidden_states = self.expand_bn(hidden_states) + hidden_states = self.expand_act(hidden_states) + + return hidden_states + + +class EfficientNetDepthwiseLayer(nn.Module): + r""" + This corresponds to the depthwise convolution phase of each block in the original implementation. + """ + + def __init__( + self, + config: EfficientNetConfig, + in_dim: int, + stride: int, + kernel_size: int, + adjust_padding: bool, + ): + super().__init__() + self.stride = stride + conv_pad = "valid" if self.stride == 2 else "same" + padding = correct_pad(kernel_size, adjust=adjust_padding) + + self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding) + self.depthwise_conv = EfficientNetDepthwiseConv2d( + in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False + ) + self.depthwise_norm = nn.BatchNorm2d( + num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.depthwise_act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + # Depthwise convolution + if self.stride == 2: + hidden_states = self.depthwise_conv_pad(hidden_states) + + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.depthwise_norm(hidden_states) + hidden_states = self.depthwise_act(hidden_states) + + return hidden_states + + +class EfficientNetSqueezeExciteLayer(nn.Module): + r""" + This corresponds to the Squeeze and Excitement phase of each block in the original implementation. + """ + + def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False): + super().__init__() + self.dim = expand_dim if expand else in_dim + self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio)) + + self.squeeze = nn.AdaptiveAvgPool2d(output_size=1) + self.reduce = nn.Conv2d( + in_channels=self.dim, + out_channels=self.dim_se, + kernel_size=1, + padding="same", + ) + self.expand = nn.Conv2d( + in_channels=self.dim_se, + out_channels=self.dim, + kernel_size=1, + padding="same", + ) + self.act_reduce = ACT2FN[config.hidden_act] + self.act_expand = nn.Sigmoid() + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + inputs = hidden_states + hidden_states = self.squeeze(hidden_states) + hidden_states = self.reduce(hidden_states) + hidden_states = self.act_reduce(hidden_states) + + hidden_states = self.expand(hidden_states) + hidden_states = self.act_expand(hidden_states) + hidden_states = torch.mul(inputs, hidden_states) + + return hidden_states + + +class EfficientNetFinalBlockLayer(nn.Module): + r""" + This corresponds to the final phase of each block in the original implementation. + """ + + def __init__( + self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool + ): + super().__init__() + self.apply_dropout = stride == 1 and not id_skip + self.project_conv = nn.Conv2d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=1, + padding="same", + bias=False, + ) + self.project_bn = nn.BatchNorm2d( + num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.dropout = nn.Dropout(p=drop_rate) + + def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor: + hidden_states = self.project_conv(hidden_states) + hidden_states = self.project_bn(hidden_states) + + if self.apply_dropout: + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + embeddings + + return hidden_states + + +class EfficientNetBlock(nn.Module): + r""" + This corresponds to the expansion and depthwise convolution phase of each block in the original implementation. + + Args: + config ([`EfficientNetConfig`]): + Model configuration class. + in_dim (`int`): + Number of input channels. + out_dim (`int`): + Number of output channels. + stride (`int`): + Stride size to be used in convolution layers. + expand_ratio (`int`): + Expand ratio to set the output dimensions for the expansion and squeeze-excite layers. + kernel_size (`int`): + Kernel size for the depthwise convolution layer. + drop_rate (`float`): + Dropout rate to be used in the final phase of each block. + id_skip (`bool`): + Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase + of each block. Set to `True` for the first block of each stage. + adjust_padding (`bool`): + Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution + operation, set to `True` for inputs with odd input sizes. + """ + + def __init__( + self, + config: EfficientNetConfig, + in_dim: int, + out_dim: int, + stride: int, + expand_ratio: int, + kernel_size: int, + drop_rate: float, + id_skip: bool, + adjust_padding: bool, + ): + super().__init__() + self.expand_ratio = expand_ratio + self.expand = True if self.expand_ratio != 1 else False + expand_in_dim = in_dim * expand_ratio + + if self.expand: + self.expansion = EfficientNetExpansionLayer( + config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride + ) + + self.depthwise_conv = EfficientNetDepthwiseLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + stride=stride, + kernel_size=kernel_size, + adjust_padding=adjust_padding, + ) + self.squeeze_excite = EfficientNetSqueezeExciteLayer( + config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand + ) + self.projection = EfficientNetFinalBlockLayer( + config=config, + in_dim=expand_in_dim if self.expand else in_dim, + out_dim=out_dim, + stride=stride, + drop_rate=drop_rate, + id_skip=id_skip, + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor: + embeddings = hidden_states + # Expansion and depthwise convolution phase + if self.expand_ratio != 1: + hidden_states = self.expansion(hidden_states) + hidden_states = self.depthwise_conv(hidden_states) + + # Squeeze and excite phase + hidden_states = self.squeeze_excite(hidden_states) + hidden_states = self.projection(embeddings, hidden_states) + return hidden_states + + +class EfficientNetEncoder(nn.Module): + r""" + Forward propogates the embeddings through each EfficientNet block. + + Args: + config ([`EfficientNetConfig`]): + Model configuration class. + """ + + def __init__(self, config: EfficientNetConfig): + super().__init__() + self.config = config + self.depth_coefficient = config.depth_coefficient + + def round_repeats(repeats): + # Round number of block repeats based on depth multiplier. + return int(math.ceil(self.depth_coefficient * repeats)) + + num_base_blocks = len(config.in_channels) + num_blocks = sum(round_repeats(n) for n in config.num_block_repeats) + + curr_block_num = 0 + blocks = [] + for i in range(num_base_blocks): + in_dim = round_filters(config, config.in_channels[i]) + out_dim = round_filters(config, config.out_channels[i]) + stride = config.strides[i] + kernel_size = config.kernel_sizes[i] + expand_ratio = config.expand_ratios[i] + + for j in range(round_repeats(config.num_block_repeats[i])): + id_skip = True if j == 0 else False + stride = 1 if j > 0 else stride + in_dim = out_dim if j > 0 else in_dim + adjust_padding = False if curr_block_num in config.depthwise_padding else True + drop_rate = config.drop_connect_rate * curr_block_num / num_blocks + + block = EfficientNetBlock( + config=config, + in_dim=in_dim, + out_dim=out_dim, + stride=stride, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + drop_rate=drop_rate, + id_skip=id_skip, + adjust_padding=adjust_padding, + ) + blocks.append(block) + curr_block_num += 1 + + self.blocks = nn.ModuleList(blocks) + self.top_conv = nn.Conv2d( + in_channels=out_dim, + out_channels=round_filters(config, 1280), + kernel_size=1, + padding="same", + bias=False, + ) + self.top_bn = nn.BatchNorm2d( + num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum + ) + self.top_activation = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.FloatTensor, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> BaseModelOutputWithNoAttention: + all_hidden_states = (hidden_states,) if output_hidden_states else None + + for block in self.blocks: + hidden_states = block(hidden_states) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.top_conv(hidden_states) + hidden_states = self.top_bn(hidden_states) + hidden_states = self.top_activation(hidden_states) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + +class EfficientNetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EfficientNetConfig + base_model_prefix = "efficientnet" + main_input_name = "pixel_values" + _no_split_modules = [] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@add_start_docstrings( + "The bare EfficientNet model outputting raw features without any specific head on top.", + EFFICIENTNET_START_DOCSTRING, +) +class EfficientNetModel(EfficientNetPreTrainedModel): + def __init__(self, config: EfficientNetConfig): + super().__init__(config) + self.config = config + self.embeddings = EfficientNetEmbeddings(config) + self.encoder = EfficientNetEncoder(config) + + # Final pooling layer + if config.pooling_type == "mean": + self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True) + elif config.pooling_type == "max": + self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True) + else: + raise ValueError(f"config.pooling must be one of ['mean', 'max'] got {config.pooling}") + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndNoAttention, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # Apply pooling + last_hidden_state = encoder_outputs[0] + pooled_output = self.pooler(last_hidden_state) + # Reshape (batch_size, 1280, 1 , 1) -> (batch_size, 1280) + pooled_output = pooled_output.reshape(pooled_output.shape[:2]) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndNoAttention( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +@add_start_docstrings( + """ + EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. + for ImageNet. + """, + EFFICIENTNET_START_DOCSTRING, +) +class EfficientNetForImageClassification(EfficientNetPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.efficientnet = EfficientNetModel(config) + # Classifier head + self.dropout = nn.Dropout(p=config.dropout_rate) + self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutputWithNoAttention, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict) + + pooled_output = outputs.pooler_output if return_dict else outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + ) + + +__all__ = ["EfficientNetForImageClassification", "EfficientNetModel", "EfficientNetPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/electra/__init__.py b/docs/transformers/build/lib/transformers/models/electra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a78ed5c42aea51038335efabde5b03e333592ed6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_electra import * + from .modeling_electra import * + from .modeling_flax_electra import * + from .modeling_tf_electra import * + from .tokenization_electra import * + from .tokenization_electra_fast import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py b/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py new file mode 100644 index 0000000000000000000000000000000000000000..20b242c0f8d65fd5a4a3109fdf3b58a3ff7b9181 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/configuration_electra.py @@ -0,0 +1,187 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ELECTRA model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ElectraConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ElectraModel`] or a [`TFElectraModel`]. It is + used to instantiate a ELECTRA model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the ELECTRA + [google/electra-small-discriminator](https://huggingface.co/google/electra-small-discriminator) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the ELECTRA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`]. + embedding_size (`int`, *optional*, defaults to 128): + Dimensionality of the encoder layers and the pooler layer. + hidden_size (`int`, *optional*, defaults to 256): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 4): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 1024): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ElectraModel`] or [`TFElectraModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + summary_type (`str`, *optional*, defaults to `"first"`): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + Pass `"gelu"` for a gelu activation to the output, any other value will result in no activation. + summary_last_dropout (`float`, *optional*, defaults to 0.0): + Argument used when doing sequence summary. Used in the sequence classification and multiple choice models. + + The dropout ratio to be used after the projection and activation. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import ElectraConfig, ElectraModel + + >>> # Initializing a ELECTRA electra-base-uncased style configuration + >>> configuration = ElectraConfig() + + >>> # Initializing a model (with random weights) from the electra-base-uncased style configuration + >>> model = ElectraModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "electra" + + def __init__( + self, + vocab_size=30522, + embedding_size=128, + hidden_size=256, + num_hidden_layers=12, + num_attention_heads=4, + intermediate_size=1024, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + summary_type="first", + summary_use_proj=True, + summary_activation="gelu", + summary_last_dropout=0.1, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.embedding_size = embedding_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_last_dropout = summary_last_dropout + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class ElectraOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["ElectraConfig", "ElectraOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..b0abc30cd758743b243baabbf1298bcc2e1e595e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py @@ -0,0 +1,79 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ELECTRA checkpoint.""" + +import argparse + +import torch + +from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra +from transformers.utils import logging + + +logging.set_verbosity_info() + + +def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, discriminator_or_generator): + # Initialise PyTorch model + config = ElectraConfig.from_json_file(config_file) + print(f"Building PyTorch model from configuration: {config}") + + if discriminator_or_generator == "discriminator": + model = ElectraForPreTraining(config) + elif discriminator_or_generator == "generator": + model = ElectraForMaskedLM(config) + else: + raise ValueError("The discriminator_or_generator argument should be either 'discriminator' or 'generator'") + + # Load weights from tf checkpoint + load_tf_weights_in_electra( + model, config, tf_checkpoint_path, discriminator_or_generator=discriminator_or_generator + ) + + # Save pytorch-model + print(f"Save PyTorch model to {pytorch_dump_path}") + torch.save(model.state_dict(), pytorch_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path." + ) + parser.add_argument( + "--config_file", + default=None, + type=str, + required=True, + help="The config json file corresponding to the pre-trained model. \nThis specifies the model architecture.", + ) + parser.add_argument( + "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--discriminator_or_generator", + default=None, + type=str, + required=True, + help=( + "Whether to export the generator or the discriminator. Should be a string, either 'discriminator' or " + "'generator'." + ), + ) + args = parser.parse_args() + convert_tf_checkpoint_to_pytorch( + args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.discriminator_or_generator + ) diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py new file mode 100644 index 0000000000000000000000000000000000000000..08cc3e530d6f4a1753dd1e178ec178993e6c951d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/modeling_electra.py @@ -0,0 +1,1777 @@ +# coding=utf-8 +# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ELECTRA model.""" + +import math +import os +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, get_activation +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_electra import ElectraConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" +_CONFIG_FOR_DOC = "ElectraConfig" + + +def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + for name, array in zip(names, arrays): + original_name: str = name + + try: + if isinstance(model, ElectraForMaskedLM): + name = name.replace("electra/embeddings/", "generator/embeddings/") + + if discriminator_or_generator == "generator": + name = name.replace("electra/", "discriminator/") + name = name.replace("generator/", "electra/") + + name = name.replace("dense_1", "dense_prediction") + name = name.replace("generator_predictions/output_bias", "generator_lm_head/bias") + + name = name.split("/") + # print(original_name, name) + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["global_step", "temperature"] for n in name): + logger.info(f"Skipping {original_name}") + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "kernel" or scope_names[0] == "gamma": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "output_weights": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if m_name.endswith("_embeddings"): + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + print(f"Initialize PyTorch weight {name}", original_name) + pointer.data = torch.from_numpy(array) + except AttributeError as e: + print(f"Skipping {original_name}", name, e) + continue + return model + + +class ElectraEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra +class ElectraSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ElectraModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class ElectraSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ELECTRA_SELF_ATTENTION_CLASSES = { + "eager": ElectraSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra,BERT->ELECTRA +class ElectraAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ELECTRA_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = ElectraSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class ElectraIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class ElectraOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra +class ElectraLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ElectraAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ElectraAttention(config, position_embedding_type="absolute") + self.intermediate = ElectraIntermediate(config) + self.output = ElectraOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Electra +class ElectraEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class ElectraDiscriminatorPredictions(nn.Module): + """Prediction module for the discriminator, made up of two dense layers.""" + + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = get_activation(config.hidden_act) + self.dense_prediction = nn.Linear(config.hidden_size, 1) + self.config = config + + def forward(self, discriminator_hidden_states): + hidden_states = self.dense(discriminator_hidden_states) + hidden_states = self.activation(hidden_states) + logits = self.dense_prediction(hidden_states).squeeze(-1) + + return logits + + +class ElectraGeneratorPredictions(nn.Module): + """Prediction module for the generator, made up of two dense layers.""" + + def __init__(self, config): + super().__init__() + + self.activation = get_activation("gelu") + self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.embedding_size) + + def forward(self, generator_hidden_states): + hidden_states = self.dense(generator_hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + +class ElectraPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ElectraConfig + load_tf_weights = load_tf_weights_in_electra + base_model_prefix = "electra" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +class ElectraForPreTrainingOutput(ModelOutput): + """ + Output type of [`ElectraForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss of the ELECTRA objective. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Prediction scores of the head (scores for each token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +ELECTRA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ElectraConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ELECTRA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + encoder_hidden_states (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to " + "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the " + "hidden size and embedding size are different. " + "" + "Both the generator and discriminator checkpoints may be loaded into this model.", + ELECTRA_START_DOCSTRING, +) +class ElectraModel(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.embeddings = ElectraEmbeddings(config) + + if config.embedding_size != config.hidden_size: + self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size) + + self.encoder = ElectraEncoder(config) + self.config = config + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if hasattr(self, "embeddings_project"): + hidden_states = self.embeddings_project(hidden_states) + + hidden_states = self.encoder( + hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return hidden_states + + +class ElectraClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.activation = get_activation("gelu") + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation(x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = self.dropout(x) + x = self.out_proj(x) + return x + + +# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Electra +class ElectraSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`ElectraConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: ElectraConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = nn.Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity() + + self.first_dropout = nn.Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = nn.Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. + """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +@add_start_docstrings( + """ + ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForSequenceClassification(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.electra = ElectraModel(config) + self.classifier = ElectraClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-emotion", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'joy'", + expected_loss=0.06, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + discriminator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = discriminator_hidden_states[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + discriminator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ + Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. + + It is recommended to load the discriminator checkpoint into that model. + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForPreTraining(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.electra = ElectraModel(config) + self.discriminator_predictions = ElectraDiscriminatorPredictions(config) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=ElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], ElectraForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring) + Indices should be in `[0, 1]`: + + - 0 indicates the token is an original token, + - 1 indicates the token was replaced. + + Returns: + + Examples: + + ```python + >>> from transformers import ElectraForPreTraining, AutoTokenizer + >>> import torch + + >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator") + >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator") + + >>> sentence = "The quick brown fox jumps over the lazy dog" + >>> fake_sentence = "The quick brown fox fake over the lazy dog" + + >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True) + >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt") + >>> discriminator_outputs = discriminator(fake_inputs) + >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2) + + >>> fake_tokens + ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]'] + + >>> predictions.squeeze().tolist() + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + discriminator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + discriminator_sequence_output = discriminator_hidden_states[0] + + logits = self.discriminator_predictions(discriminator_sequence_output) + + loss = None + if labels is not None: + loss_fct = nn.BCEWithLogitsLoss() + if attention_mask is not None: + active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1 + active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss] + active_labels = labels[active_loss] + loss = loss_fct(active_logits, active_labels.float()) + else: + loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float()) + + if not return_dict: + output = (logits,) + discriminator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return ElectraForPreTrainingOutput( + loss=loss, + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ + Electra model with a language modeling head on top. + + Even though both the discriminator and generator may be loaded into this model, the generator is the only model of + the two to have been trained for the masked language modeling task. + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForMaskedLM(ElectraPreTrainedModel): + _tied_weights_keys = ["generator_lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + self.electra = ElectraModel(config) + self.generator_predictions = ElectraGeneratorPredictions(config) + + self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.generator_lm_head + + def set_output_embeddings(self, word_embeddings): + self.generator_lm_head = word_embeddings + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/electra-small-generator", + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="[MASK]", + expected_output="'paris'", + expected_loss=1.22, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + generator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + generator_sequence_output = generator_hidden_states[0] + + prediction_scores = self.generator_predictions(generator_sequence_output) + prediction_scores = self.generator_lm_head(prediction_scores) + + loss = None + # Masked language modeling softmax layer + if labels is not None: + loss_fct = nn.CrossEntropyLoss() # -100 index = padding token + loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + generator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return MaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=generator_hidden_states.hidden_states, + attentions=generator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ + Electra model with a token classification head on top. + + Both the discriminator and generator may be loaded into this model. + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForTokenClassification(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.electra = ElectraModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']", + expected_loss=0.11, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + discriminator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + discriminator_sequence_output = discriminator_hidden_states[0] + + discriminator_sequence_output = self.dropout(discriminator_sequence_output) + logits = self.classifier(discriminator_sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + discriminator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ + ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForQuestionAnswering(ElectraPreTrainedModel): + config_class = ElectraConfig + base_model_prefix = "electra" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.electra = ElectraModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-squad2", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=11, + qa_target_end_index=12, + expected_output="'a nice puppet'", + expected_loss=2.64, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + discriminator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + sequence_output = discriminator_hidden_states[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + discriminator_hidden_states[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ + ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ELECTRA_START_DOCSTRING, +) +class ElectraForMultipleChoice(ElectraPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.electra = ElectraModel(config) + self.sequence_summary = ElectraSequenceSummary(config) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + discriminator_hidden_states = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = discriminator_hidden_states[0] + + pooled_output = self.sequence_summary(sequence_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + discriminator_hidden_states[1:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + +@add_start_docstrings( + """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING +) +class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["generator_lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`") + + self.electra = ElectraModel(config) + self.generator_predictions = ElectraGeneratorPredictions(config) + self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size) + + self.init_weights() + + def get_output_embeddings(self): + return self.generator_lm_head + + def set_output_embeddings(self, new_embeddings): + self.generator_lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator") + >>> config = ElectraConfig.from_pretrained("google/electra-base-generator") + >>> config.is_decoder = True + >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.electra( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output)) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +__all__ = [ + "ElectraForCausalLM", + "ElectraForMaskedLM", + "ElectraForMultipleChoice", + "ElectraForPreTraining", + "ElectraForQuestionAnswering", + "ElectraForSequenceClassification", + "ElectraForTokenClassification", + "ElectraModel", + "ElectraPreTrainedModel", + "load_tf_weights_in_electra", +] diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py new file mode 100644 index 0000000000000000000000000000000000000000..77a445e6ccaa7ce3294655894be874bc83300a38 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/modeling_flax_electra.py @@ -0,0 +1,1614 @@ +# coding=utf-8 +# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional, Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_electra import ElectraConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" +_CONFIG_FOR_DOC = "ElectraConfig" + +remat = nn_partitioning.remat + + +@flax.struct.dataclass +class FlaxElectraForPreTrainingOutput(ModelOutput): + """ + Output type of [`ElectraForPreTraining`]. + + Args: + logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +ELECTRA_START_DOCSTRING = r""" + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading, saving and converting weights from PyTorch models) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`ElectraConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ELECTRA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + +""" + + +class FlaxElectraEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.embedding_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__ + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra +class FlaxElectraSelfAttention(nn.Module): + config: ElectraConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slightly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra +class FlaxElectraSelfOutput(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra +class FlaxElectraAttention(nn.Module): + config: ElectraConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra +class FlaxElectraIntermediate(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra +class FlaxElectraOutput(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra +class FlaxElectraLayer(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype) + self.output = FlaxElectraOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra +class FlaxElectraLayerCollection(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra +class FlaxElectraEncoder(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxElectraLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxElectraGeneratorPredictions(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN[self.config.hidden_act](hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class FlaxElectraDiscriminatorPredictions(nn.Module): + """Prediction module for the discriminator, made up of two dense layers.""" + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.dense_prediction = nn.Dense(1, dtype=self.dtype) + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN[self.config.hidden_act](hidden_states) + hidden_states = self.dense_prediction(hidden_states).squeeze(-1) + return hidden_states + + +class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ElectraConfig + base_model_prefix = "electra" + module_class: nn.Module = None + + def __init__( + self, + config: ElectraConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.zeros_like(input_ids) + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.ones_like(input_ids) + + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxElectraAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +class FlaxElectraModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) + if self.config.embedding_size != self.config.hidden_size: + self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype) + self.encoder = FlaxElectraEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask: Optional[np.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + embeddings = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + if hasattr(self, "embeddings_project"): + embeddings = self.embeddings_project(embeddings) + + return self.encoder( + embeddings, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings( + "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.", + ELECTRA_START_DOCSTRING, +) +class FlaxElectraModel(FlaxElectraPreTrainedModel): + module_class = FlaxElectraModule + + +append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC) + + +class FlaxElectraTiedDense(nn.Module): + embedding_size: int + dtype: jnp.dtype = jnp.float32 + precision = None + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.bias = self.param("bias", self.bias_init, (self.embedding_size,)) + + def __call__(self, x, kernel): + x = jnp.asarray(x, self.dtype) + kernel = jnp.asarray(kernel, self.dtype) + y = lax.dot_general( + x, + kernel, + (((x.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) + bias = jnp.asarray(self.bias, self.dtype) + return y + bias + + +class FlaxElectraForMaskedLMModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) + if self.config.tie_word_embeddings: + self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) + else: + self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + prediction_scores = self.generator_predictions(hidden_states) + + if self.config.tie_word_embeddings: + shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) + else: + prediction_scores = self.generator_lm_head(prediction_scores) + + if not return_dict: + return (prediction_scores,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING) +class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForMaskedLMModule + + +append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) + + +class FlaxElectraForPreTrainingModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + logits = self.discriminator_predictions(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxElectraForPreTrainingOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. + + It is recommended to load the discriminator checkpoint into that model. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForPreTrainingModule + + +FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") + >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ``` +""" + +overwrite_call_docstring( + FlaxElectraForPreTraining, + ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC +) + + +class FlaxElectraForTokenClassificationModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Electra model with a token classification head on top. + + Both the discriminator and generator may be loaded into this model. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForTokenClassificationModule + + +append_call_sample_docstring( + FlaxElectraForTokenClassification, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +def identity(x, **kwargs): + return x + + +class FlaxElectraSequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.summary = identity + if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj: + if ( + hasattr(self.config, "summary_proj_to_labels") + and self.config.summary_proj_to_labels + and self.config.num_labels > 0 + ): + num_classes = self.config.num_labels + else: + num_classes = self.config.hidden_size + self.summary = nn.Dense(num_classes, dtype=self.dtype) + + activation_string = getattr(self.config, "summary_activation", None) + self.activation = ACT2FN[activation_string] if activation_string else lambda x: x # noqa F407 + + self.first_dropout = identity + if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(self.config.summary_first_dropout) + + self.last_dropout = identity + if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(self.config.summary_last_dropout) + + def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `jnp.ndarray`: The summary of the sequence hidden states. + """ + # NOTE: this doest "first" type summary always + output = hidden_states[:, 0] + output = self.first_dropout(output, deterministic=deterministic) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output, deterministic=deterministic) + return output + + +class FlaxElectraForMultipleChoiceModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[1:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForMultipleChoiceModule + + +# adapt docstring slightly for FlaxElectraForMultipleChoice +overwrite_call_docstring( + FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxElectraForMultipleChoice, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxElectraForQuestionAnsweringModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForQuestionAnsweringModule + + +append_call_sample_docstring( + FlaxElectraForQuestionAnswering, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxElectraClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True): + x = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, deterministic=deterministic) + x = self.dense(x) + x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu + x = self.dropout(x, deterministic=deterministic) + x = self.out_proj(x) + return x + + +class FlaxElectraForSequenceClassificationModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + logits = self.classifier(hidden_states, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ELECTRA_START_DOCSTRING, +) +class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForSequenceClassificationModule + + +append_call_sample_docstring( + FlaxElectraForSequenceClassification, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +class FlaxElectraForCausalLMModule(nn.Module): + config: ElectraConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) + self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) + if self.config.tie_word_embeddings: + self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) + else: + self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask: Optional[jnp.ndarray] = None, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + outputs = self.electra( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + prediction_scores = self.generator_predictions(hidden_states) + + if self.config.tie_word_embeddings: + shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T) + else: + prediction_scores = self.generator_lm_head(prediction_scores) + + if not return_dict: + return (prediction_scores,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + ELECTRA_START_DOCSTRING, +) +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra +class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel): + module_class = FlaxElectraForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxElectraForCausalLM, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) + + +__all__ = [ + "FlaxElectraForCausalLM", + "FlaxElectraForMaskedLM", + "FlaxElectraForMultipleChoice", + "FlaxElectraForPreTraining", + "FlaxElectraForQuestionAnswering", + "FlaxElectraForSequenceClassification", + "FlaxElectraForTokenClassification", + "FlaxElectraModel", + "FlaxElectraPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py b/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc3ac8ebf8cb32056aac75586f9f5eb15f58e7e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/modeling_tf_electra.py @@ -0,0 +1,1776 @@ +# coding=utf-8 +# Copyright 2019 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF Electra model.""" + +from __future__ import annotations + +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFSequenceSummary, + TFTokenClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_electra import ElectraConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator" +_CONFIG_FOR_DOC = "ElectraConfig" + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->Electra +class TFElectraSelfAttention(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFElectraModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->Electra +class TFElectraSelfOutput(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->Electra +class TFElectraAttention(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFElectraSelfAttention(config, name="self") + self.dense_output = TFElectraSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->Electra +class TFElectraIntermediate(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->Electra +class TFElectraOutput(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->Electra +class TFElectraLayer(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFElectraAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFElectraAttention(config, name="crossattention") + self.intermediate = TFElectraIntermediate(config, name="intermediate") + self.bert_output = TFElectraOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_value: Tuple[tf.Tensor] | None, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "bert_output", None) is not None: + with tf.name_scope(self.bert_output.name): + self.bert_output.build(None) + if getattr(self, "crossattention", None) is not None: + with tf.name_scope(self.crossattention.name): + self.crossattention.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->Electra +class TFElectraEncoder(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFElectraLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor | None, + encoder_attention_mask: tf.Tensor | None, + past_key_values: Tuple[Tuple[tf.Tensor]] | None, + use_cache: Optional[bool], + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Electra +class TFElectraPooler(keras.layers.Layer): + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.albert.modeling_tf_albert.TFAlbertEmbeddings with Albert->Electra +class TFElectraEmbeddings(keras.layers.Layer): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config: ElectraConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = config.embedding_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape=None): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.config.vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.config.type_vocab_size, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.embedding_size], + initializer=get_initializer(self.initializer_range), + ) + + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.embedding_size]) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call + def call( + self, + input_ids: Optional[tf.Tensor] = None, + position_ids: Optional[tf.Tensor] = None, + token_type_ids: Optional[tf.Tensor] = None, + inputs_embeds: Optional[tf.Tensor] = None, + past_key_values_length=0, + training: bool = False, + ) -> tf.Tensor: + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError("Need to provide either `input_ids` or `input_embeds`.") + + if input_ids is not None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + position_ids = tf.expand_dims( + tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +class TFElectraDiscriminatorPredictions(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense(config.hidden_size, name="dense") + self.dense_prediction = keras.layers.Dense(1, name="dense_prediction") + self.config = config + + def call(self, discriminator_hidden_states, training=False): + hidden_states = self.dense(discriminator_hidden_states) + hidden_states = get_tf_activation(self.config.hidden_act)(hidden_states) + logits = tf.squeeze(self.dense_prediction(hidden_states), -1) + + return logits + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "dense_prediction", None) is not None: + with tf.name_scope(self.dense_prediction.name): + self.dense_prediction.build([None, None, self.config.hidden_size]) + + +class TFElectraGeneratorPredictions(keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dense = keras.layers.Dense(config.embedding_size, name="dense") + self.config = config + + def call(self, generator_hidden_states, training=False): + hidden_states = self.dense(generator_hidden_states) + hidden_states = get_tf_activation("gelu")(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.embedding_size]) + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFElectraPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ElectraConfig + base_model_prefix = "electra" + # When the model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"generator_lm_head.weight"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + +@keras_serializable +class TFElectraMainLayer(keras.layers.Layer): + config_class = ElectraConfig + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFElectraEmbeddings(config, name="embeddings") + + if config.embedding_size != config.hidden_size: + self.embeddings_project = keras.layers.Dense(config.hidden_size, name="embeddings_project") + + self.encoder = TFElectraEncoder(config, name="encoder") + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, value): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + def get_extended_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length=0): + batch_size, seq_length = input_shape + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values_length > 0: + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=dtype) + one_cst = tf.constant(1.0, dtype=dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + return extended_attention_mask + + def get_head_mask(self, head_mask): + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + return head_mask + + @unpack_inputs + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, hidden_states.dtype, past_key_values_length + ) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + head_mask = self.get_head_mask(head_mask) + + if hasattr(self, "embeddings_project"): + hidden_states = self.embeddings_project(hidden_states, training=training) + + hidden_states = self.encoder( + hidden_states=hidden_states, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "embeddings_project", None) is not None: + with tf.name_scope(self.embeddings_project.name): + self.embeddings_project.build([None, None, self.config.embedding_size]) + + +@dataclass +class TFElectraForPreTrainingOutput(ModelOutput): + """ + Output type of [`TFElectraForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `tf.Tensor` of shape `(1,)`): + Total loss of the ELECTRA objective. + logits (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Prediction scores of the head (scores for each token before SoftMax). + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +ELECTRA_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_ids` only and nothing else: `model(input_ids)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_ids": input_ids, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Parameters: + config ([`ElectraConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ELECTRA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" + + +@add_start_docstrings( + "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to " + "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the " + "hidden size and embedding size are different. " + "" + "Both the generator and discriminator checkpoints may be loaded into this model.", + ELECTRA_START_DOCSTRING, +) +class TFElectraModel(TFElectraPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.electra = TFElectraMainLayer(config, name="electra") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + + +@add_start_docstrings( + """ + Electra model with a binary classification head on top as used during pretraining for identifying generated tokens. + + Even though both the discriminator and generator may be loaded into this model, the discriminator is the only model + of the two to have the correct classification head to be used for this model. + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForPreTraining(TFElectraPreTrainedModel): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.electra = TFElectraMainLayer(config, name="electra") + self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFElectraForPreTrainingOutput, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> import tensorflow as tf + >>> from transformers import AutoTokenizer, TFElectraForPreTraining + + >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator") + >>> model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator") + >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1 + >>> outputs = model(input_ids) + >>> scores = outputs[0] + ```""" + discriminator_hidden_states = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + discriminator_sequence_output = discriminator_hidden_states[0] + logits = self.discriminator_predictions(discriminator_sequence_output) + + if not return_dict: + return (logits,) + discriminator_hidden_states[1:] + + return TFElectraForPreTrainingOutput( + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "discriminator_predictions", None) is not None: + with tf.name_scope(self.discriminator_predictions.name): + self.discriminator_predictions.build(None) + + +class TFElectraMaskedLMHead(keras.layers.Layer): + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.embedding_size = config.embedding_size + self.input_embeddings = input_embeddings + + def build(self, input_shape): + self.bias = self.add_weight(shape=(self.config.vocab_size,), initializer="zeros", trainable=True, name="bias") + + super().build(input_shape) + + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.config.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.config.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings( + """ + Electra model with a language modeling head on top. + + Even though both the discriminator and generator may be loaded into this model, the generator is the only model of + the two to have been trained for the masked language modeling task. + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.config = config + self.electra = TFElectraMainLayer(config, name="electra") + self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions") + + if isinstance(config.hidden_act, str): + self.activation = get_tf_activation(config.hidden_act) + else: + self.activation = config.hidden_act + + self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head") + + def get_lm_head(self): + return self.generator_lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.generator_lm_head.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="google/electra-small-generator", + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="[MASK]", + expected_output="'paris'", + expected_loss=1.22, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + generator_hidden_states = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + generator_sequence_output = generator_hidden_states[0] + prediction_scores = self.generator_predictions(generator_sequence_output, training=training) + prediction_scores = self.generator_lm_head(prediction_scores, training=training) + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) + + if not return_dict: + output = (prediction_scores,) + generator_hidden_states[1:] + + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=generator_hidden_states.hidden_states, + attentions=generator_hidden_states.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "generator_predictions", None) is not None: + with tf.name_scope(self.generator_predictions.name): + self.generator_predictions.build(None) + if getattr(self, "generator_lm_head", None) is not None: + with tf.name_scope(self.generator_lm_head.name): + self.generator_lm_head.build(None) + + +class TFElectraClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + classifier_dropout = ( + config.classifhidden_dropout_probier_dropout + if config.classifier_dropout is not None + else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.out_proj = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" + ) + self.config = config + + def call(self, inputs, **kwargs): + x = inputs[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = get_tf_activation("gelu")(x) # although BERT uses tanh here, it seems Electra authors used gelu here + x = self.dropout(x) + x = self.out_proj(x) + + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + self.electra = TFElectraMainLayer(config, name="electra") + self.classifier = TFElectraClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-emotion", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'joy'", + expected_loss=0.06, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + logits = self.classifier(outputs[0]) + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.electra = TFElectraMainLayer(config, name="electra") + self.sequence_summary = TFSequenceSummary( + config, initializer_range=config.initializer_range, name="sequence_summary" + ) + self.classifier = keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + flat_inputs_embeds = ( + tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3])) + if inputs_embeds is not None + else None + ) + outputs = self.electra( + input_ids=flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + logits = self.sequence_summary(outputs[0]) + logits = self.classifier(logits) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[1:] + + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "sequence_summary", None) is not None: + with tf.name_scope(self.sequence_summary.name): + self.sequence_summary.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Electra model with a token classification head on top. + + Both the discriminator and generator may be loaded into this model. + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + + self.electra = TFElectraMainLayer(config, name="electra") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = keras.layers.Dropout(classifier_dropout) + self.classifier = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-discriminator-finetuned-conll03-english", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['B-LOC', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'O', 'B-LOC', 'I-LOC']", + expected_loss=0.11, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + discriminator_hidden_states = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + discriminator_sequence_output = discriminator_hidden_states[0] + discriminator_sequence_output = self.dropout(discriminator_sequence_output) + logits = self.classifier(discriminator_sequence_output) + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + discriminator_hidden_states[1:] + + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ELECTRA_START_DOCSTRING, +) +class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.electra = TFElectraMainLayer(config, name="electra") + self.qa_outputs = keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="bhadresh-savani/electra-base-squad2", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + qa_target_start_index=11, + qa_target_end_index=12, + expected_output="'a nice puppet'", + expected_loss=2.64, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + token_type_ids: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + start_positions: np.ndarray | tf.Tensor | None = None, + end_positions: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + discriminator_hidden_states = self.electra( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + discriminator_sequence_output = discriminator_hidden_states[0] + logits = self.qa_outputs(discriminator_sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + loss = None + + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = ( + start_logits, + end_logits, + ) + discriminator_hidden_states[1:] + + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=discriminator_hidden_states.hidden_states, + attentions=discriminator_hidden_states.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "electra", None) is not None: + with tf.name_scope(self.electra.name): + self.electra.build(None) + if getattr(self, "qa_outputs", None) is not None: + with tf.name_scope(self.qa_outputs.name): + self.qa_outputs.build([None, None, self.config.hidden_size]) + + +__all__ = [ + "TFElectraForMaskedLM", + "TFElectraForMultipleChoice", + "TFElectraForPreTraining", + "TFElectraForQuestionAnswering", + "TFElectraForSequenceClassification", + "TFElectraForTokenClassification", + "TFElectraModel", + "TFElectraPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py new file mode 100644 index 0000000000000000000000000000000000000000..3b21527e6cdae25e1bdc40a81a01f3b2f014ffb5 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra.py @@ -0,0 +1,511 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import os +import unicodedata +from typing import List, Optional, Tuple + +from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +# Copied from transformers.models.bert.tokenization_bert.load_vocab +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip("\n") + vocab[token] = index + return vocab + + +# Copied from transformers.models.bert.tokenization_bert.whitespace_tokenize +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +# Copied from transformers.models.bert.tokenization_bert.BertTokenizer with Bert->Electra,BERT->Electra +class ElectraTokenizer(PreTrainedTokenizer): + r""" + Construct a Electra tokenizer. Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + do_basic_tokenize (`bool`, *optional*, defaults to `True`): + Whether or not to do basic tokenization before WordPiece. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original Electra). + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + """ + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + do_lower_case=True, + do_basic_tokenize=True, + never_split=None, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + clean_up_tokenization_spaces=True, + **kwargs, + ): + if not os.path.isfile(vocab_file): + raise ValueError( + f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained" + " model use `tokenizer = ElectraTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer( + do_lower_case=do_lower_case, + never_split=never_split, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + ) + + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token)) + + super().__init__( + do_lower_case=do_lower_case, + do_basic_tokenize=do_basic_tokenize, + never_split=never_split, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def do_lower_case(self): + return self.basic_tokenizer.do_lower_case + + @property + def vocab_size(self): + return len(self.vocab) + + def get_vocab(self): + return dict(self.vocab, **self.added_tokens_encoder) + + def _tokenize(self, text, split_special_tokens=False): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize( + text, never_split=self.all_special_tokens if not split_special_tokens else None + ): + # If the token is part of the never_split set + if token in self.basic_tokenizer.never_split: + split_tokens.append(token) + else: + split_tokens += self.wordpiece_tokenizer.tokenize(token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.ids_to_tokens.get(index, self.unk_token) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + out_string = " ".join(tokens).replace(" ##", "").strip() + return out_string + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A Electra sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A Electra sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + index = 0 + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + else: + vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning( + f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." + " Please check that the vocabulary is not corrupted!" + ) + index = token_index + writer.write(token + "\n") + index += 1 + return (vocab_file,) + + +# Copied from transformers.models.bert.tokenization_bert.BasicTokenizer +class BasicTokenizer: + """ + Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). + + Args: + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + never_split (`Iterable`, *optional*): + Collection of tokens which will never be split during tokenization. Only has an effect when + `do_basic_tokenize=True` + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. + + This should likely be deactivated for Japanese (see this + [issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original BERT). + do_split_on_punc (`bool`, *optional*, defaults to `True`): + In some instances we want to skip the basic punctuation splitting so that later tokenization can capture + the full context of the words, such as contractions. + """ + + def __init__( + self, + do_lower_case=True, + never_split=None, + tokenize_chinese_chars=True, + strip_accents=None, + do_split_on_punc=True, + ): + if never_split is None: + never_split = [] + self.do_lower_case = do_lower_case + self.never_split = set(never_split) + self.tokenize_chinese_chars = tokenize_chinese_chars + self.strip_accents = strip_accents + self.do_split_on_punc = do_split_on_punc + + def tokenize(self, text, never_split=None): + """ + Basic Tokenization of a piece of text. For sub-word tokenization, see WordPieceTokenizer. + + Args: + never_split (`List[str]`, *optional*) + Kept for backward compatibility purposes. Now implemented directly at the base class level (see + [`PreTrainedTokenizer.tokenize`]) List of token not to split. + """ + # union() returns a new set by concatenating the two sets. + never_split = self.never_split.union(set(never_split)) if never_split else self.never_split + text = self._clean_text(text) + + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + if self.tokenize_chinese_chars: + text = self._tokenize_chinese_chars(text) + # prevents treating the same character with different unicode codepoints as different characters + unicode_normalized_text = unicodedata.normalize("NFC", text) + orig_tokens = whitespace_tokenize(unicode_normalized_text) + split_tokens = [] + for token in orig_tokens: + if token not in never_split: + if self.do_lower_case: + token = token.lower() + if self.strip_accents is not False: + token = self._run_strip_accents(token) + elif self.strip_accents: + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token, never_split)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text, never_split=None): + """Splits punctuation on a piece of text.""" + if not self.do_split_on_punc or (never_split is not None and text in never_split): + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xFFFD or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +# Copied from transformers.models.bert.tokenization_bert.WordpieceTokenizer +class WordpieceTokenizer: + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token, max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """ + Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through *BasicTokenizer*. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +__all__ = ["ElectraTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..34ea4339b9382b28c6f9a5842f88af8807f9f928 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/electra/tokenization_electra_fast.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2020 The Google AI Team, Stanford University and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import List, Optional, Tuple + +from tokenizers import normalizers + +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from .tokenization_electra import ElectraTokenizer + + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"} + + +# Copied from transformers.models.bert.tokenization_bert_fast.BertTokenizerFast with Bert->Electra , BERT->ELECTRA +class ElectraTokenizerFast(PreTrainedTokenizerFast): + r""" + Construct a "fast" ELECTRA tokenizer (backed by HuggingFace's *tokenizers* library). Based on WordPiece. + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + File containing the vocabulary. + do_lower_case (`bool`, *optional*, defaults to `True`): + Whether or not to lowercase the input when tokenizing. + unk_token (`str`, *optional*, defaults to `"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (`str`, *optional*, defaults to `"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + pad_token (`str`, *optional*, defaults to `"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (`str`, *optional*, defaults to `"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + mask_token (`str`, *optional*, defaults to `"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + clean_text (`bool`, *optional*, defaults to `True`): + Whether or not to clean the text before tokenization by removing any control characters and replacing all + whitespaces by the classic one. + tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): + Whether or not to tokenize Chinese characters. This should likely be deactivated for Japanese (see [this + issue](https://github.com/huggingface/transformers/issues/328)). + strip_accents (`bool`, *optional*): + Whether or not to strip all accents. If this option is not specified, then it will be determined by the + value for `lowercase` (as in the original ELECTRA). + wordpieces_prefix (`str`, *optional*, defaults to `"##"`): + The prefix for subwords. + """ + + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = ElectraTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + do_lower_case=True, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + tokenize_chinese_chars=True, + strip_accents=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + do_lower_case=do_lower_case, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + tokenize_chinese_chars=tokenize_chinese_chars, + strip_accents=strip_accents, + **kwargs, + ) + + normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__()) + if ( + normalizer_state.get("lowercase", do_lower_case) != do_lower_case + or normalizer_state.get("strip_accents", strip_accents) != strip_accents + or normalizer_state.get("handle_chinese_chars", tokenize_chinese_chars) != tokenize_chinese_chars + ): + normalizer_class = getattr(normalizers, normalizer_state.pop("type")) + normalizer_state["lowercase"] = do_lower_case + normalizer_state["strip_accents"] = strip_accents + normalizer_state["handle_chinese_chars"] = tokenize_chinese_chars + self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state) + + self.do_lower_case = do_lower_case + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A ELECTRA sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + + if token_ids_1 is not None: + output += token_ids_1 + [self.sep_token_id] + + return output + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ELECTRA sequence + pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + files = self._tokenizer.model.save(save_directory, name=filename_prefix) + return tuple(files) + + +__all__ = ["ElectraTokenizerFast"] diff --git a/docs/transformers/build/lib/transformers/models/emu3/__init__.py b/docs/transformers/build/lib/transformers/models/emu3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8555f58d1866451c38abb5559ef5bef9545f0b0 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_emu3 import * + from .image_processing_emu3 import * + from .modeling_emu3 import * + from .processing_emu3 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5abedf4016d5959c8eeea9a3d955470c8b1f13 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/configuration_emu3.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Union + +from ...configuration_utils import PretrainedConfig +from ...modeling_rope_utils import rope_config_validation + + +class Emu3VQVAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3VQVAE`]. It is used to instantiate an VQ-VAE + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a configuration to the VQ model presented in Emu3 paper. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + codebook_size (`int`, *optional*, defaults to 32768): + Codebook size of the VQ model. + embed_dim (`int`, *optional*, defaults to 4): + Dimension of the quantized vector in codebook. + latent_channels (`int`, *optional*, defaults to 4): + Dimension of the output channel of encoder and the input channel of decoder + double_latent (`bool`, *optional*, defaults to `False`): + Whether double the output dim of the encoder. + in_channels (`int`, *optional*, defaults to 3): + Input channel of encoder. + out_channels (`int`, *optional*, defaults to 3): + Output channel of decoder. + temporal_downsample_factor (`int`, *optional*, defaults to 4): + Temporal downsample factor. + base_channels (`int`, *optional*, defaults to 256): + Basic channel number of the intermediate blocks. + channel_multiplier (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`): + Channel scaling factor of the intermediate blocks. + num_res_blocks (`int`, *optional*, defaults to 2): + Residual block number in each stage. + attn_resolutions (`List[int]`, *optional*, defaults to `[3]`): + Stage indices to apply attention. + hidden_size (`int`, *optional*, defaults to 1024): + Dimension of the hidden representations in the attention layer. + num_attention_heads (`int`, *optional*, defaults to 1): + Number of attention heads for each attention layer. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Emu3VQVAE, Emu3VQVAEConfig + + >>> # Initializing a video VQ model of Emu3 configuration + >>> configuration = Emu3VQVAEConfig() + + >>> # Initializing a model from the Emu3 VQ model style configuration + >>> model = Emu3VQVAE(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_vqgan" + base_config_key = "vq_config" + + def __init__( + self, + codebook_size: int = 32768, + embed_dim: int = 4, + latent_channels: int = 4, + double_latent: bool = False, + in_channels: int = 3, + out_channels: int = 3, + temporal_downsample_factor: int = 4, + base_channels: int = 256, + channel_multiplier: List[int] = [1, 2, 2, 4], + num_res_blocks: int = 2, + attn_resolutions: List[int] = [3], + hidden_size: int = 1024, + num_attention_heads: int = 1, + attention_dropout: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + + self.codebook_size = codebook_size + self.embed_dim = embed_dim + self.latent_channels = latent_channels + self.double_latent = double_latent + self.in_channels = in_channels + self.out_channels = out_channels + self.temporal_downsample_factor = temporal_downsample_factor + self.base_channels = base_channels + self.channel_multiplier = channel_multiplier + self.num_res_blocks = num_res_blocks + self.attn_resolutions = attn_resolutions + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.attention_dropout = attention_dropout + + +class Emu3TextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Emu3TextModel`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 184622): + Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Emu3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 9216): + The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens, + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 151643): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 151849): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 151850): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + + + ```python + >>> from transformers import Emu3Model, Emu3Config + + >>> # Initializing a Emu3-community/Emu3-Chat-hf style configuration + >>> configuration = Emu3Config() + + >>> # Initializing a model from the Emu3-community/Emu3-Chat-hf style configuration + >>> model = Emu3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "emu3_text_model" + base_config_key = "text_config" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size: int = 184622, + hidden_size: int = 4096, + intermediate_size: int = 14336, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: Optional[int] = 8, + hidden_act: str = "silu", + max_position_embeddings: int = 9216, + rms_norm_eps: float = 1e-5, + use_cache: bool = True, + pad_token_id: int = 151643, + bos_token_id: int = 151849, + eos_token_id: int = 151850, + tie_word_embeddings: bool = False, + rope_theta: float = 1000000.0, + rope_scaling: Optional = None, + mlp_bias=False, + attention_bias=False, + attention_dropout: float = 0.1, + initializer_range: float = 0.02, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.mlp_bias = mlp_bias + self.attention_bias = attention_bias + self.initializer_range = initializer_range + rope_config_validation(self) + + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class Emu3Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate a + emu3 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the + [Emu3-community/Emu3-Chat-hf](https://huggingface.co/Emu3-community/Emu3-Chat-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vq_config (`Union[Dict, Emu3VQVAEConfig]`, *optional*): + Emu3VQVAEConfig instance containing the configuration for the VQ-VAE model. + text_config (`Union[Dict, Emu3TextConfig]``, *optional*): + Emu3TextConfig instance containing the configuration for the language model. + vocabulary_map (`dict`, *optional*): + A dictionary containing the vocabulary map from the tokenizer. Used to obtain tokens from the image inputs. + """ + + model_type = "emu3" + keys_to_ignore_at_inference = ["past_key_values"] + sub_configs = {"text_config": Emu3TextConfig, "vq_config": Emu3VQVAEConfig} + + def __init__( + self, + vq_config: Union[Dict, Emu3VQVAEConfig] = None, + text_config: Union[Dict, Emu3TextConfig] = None, + vocabulary_map: Dict[int, int] = None, + **kwargs, + ): + if vq_config is None: + vq_config = Emu3VQVAEConfig() + elif isinstance(vq_config, dict): + vq_config = Emu3VQVAEConfig(**vq_config) + + if text_config is None: + text_config = Emu3TextConfig() + elif isinstance(text_config, dict): + text_config = Emu3TextConfig(**text_config) + + self.vq_config = vq_config + self.text_config = text_config + self.vocabulary_map = vocabulary_map + + super().__init__(**kwargs) + + +__all__ = ["Emu3Config", "Emu3TextConfig", "Emu3VQVAEConfig"] diff --git a/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py b/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac8db7e429031ee5157532bb8f4fb044844d281 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/convert_emu3_weights_to_hf.py @@ -0,0 +1,448 @@ +# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import re +from typing import Dict, Optional + +import requests +import torch +from accelerate import init_empty_weights +from PIL import Image + +from transformers import ( + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + Emu3Config, + Emu3ForConditionalGeneration, + Emu3ImageProcessor, + Emu3Processor, + Emu3TextConfig, + GenerationConfig, +) +from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + + +""" +Sample usage: + +``` +python src/transformers/models/emu3/convert_emu3_weights_to_hf.py \ + --vq_model_id BAAI/Emu3-VisionTokenizer --llm_model_id BAAI/Emu3-Chat --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Emu3ForConditionalGeneration, Emu3Processor + +model = Emu3ForConditionalGeneration.from_pretrained("/output/path") +processor = Emu3Processor.from_pretrained("/output/path") +``` + +""" + + +byte_encoder = bytes_to_unicode() +CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}" + + +# Tiktoken to HF conversion, thanks for Xenova +def token_bytes_to_string(b): + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + +# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960 +def bpe(mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None): + parts = [bytes([b]) for b in token] + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = mergeable_ranks.get(pair[0] + pair[1]) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :] + return parts + + +def generate_vocab_and_merges(encoder): + mergeable_ranks = encoder._mergeable_ranks + + merges = [] + vocab = {} + for token, rank in mergeable_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + + if len(token) == 1: + continue + merged = tuple(bpe(mergeable_ranks, token, max_rank=rank)) + assert len(merged) == 2 + merges.append(" ".join(map(token_bytes_to_string, merged))) + + # Also add special tokens + vocab.update(encoder._special_tokens) + return vocab, merges + + +def convert_tiktoken(tokenizer, output_dir): + encoder = tokenizer.tokenizer + vocab, merges = generate_vocab_and_merges(encoder) + added_tokens = [ + { + "id": id, + "content": content, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True, + } + for content, id in encoder._special_tokens.items() + if content != "<|extra_0|>" + ] + + # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer_config.json + tokenizer_config_template = { + "add_prefix_space": False, + "bos_token": "<|extra_203|>", + "clean_up_tokenization_spaces": False, + "eos_token": "<|extra_204|>", + "pad_token": "<|endoftext|>", + } + tokenizer_config_template.update({"tokenizer_class": "GPT2Tokenizer"}) + tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0])) + + # add placeholder image token by taking one of the reserved tokens + reserved_token_id = vocab["<|extra_0|>"] + vocab[""] = reserved_token_id + del vocab["<|extra_0|>"] + added_tokens.append( + { + "id": reserved_token_id, + "content": "", + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + "special": True, + } + ) + + os.makedirs(output_dir, exist_ok=True) + + pre_tokenizer = { + "type": "ByteLevel", + "add_prefix_space": False, + "trim_offsets": True, + "use_regex": True, + } + + # https://huggingface.co/Xenova/gpt2/raw/main/tokenizer.json + tokenizer_template = { + "version": "1.0", + "truncation": None, + "padding": None, + "added_tokens": added_tokens, + "normalizer": None, + "pre_tokenizer": pre_tokenizer, + "post_processor": None, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": True, + "trim_offsets": True, + "use_regex": True, + }, + "model": { + "type": "BPE", + "dropout": None, + "unk_token": None, + "continuing_subword_prefix": "", + "end_of_word_suffix": "", + "fuse_unk": False, + "byte_fallback": False, + "vocab": vocab, + "merges": merges, + }, + } + + # Save to files + with open(os.path.join(output_dir, "vocab.json"), "w", encoding="utf-8") as fp: + json.dump(vocab, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "tokenizer_config.json"), "w", encoding="utf-8") as fp: + json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False) + + with open(os.path.join(output_dir, "special_tokens_map.json"), "w", encoding="utf-8") as fp: + json.dump( + { + "bos_token": "<|extra_203|>", + "eos_token": "<|extra_204|>", + "pad_token": "<|endoftext|>", + }, + fp, + indent=2, + ensure_ascii=False, + ) + + with open(os.path.join(output_dir, "merges.txt"), "w", encoding="utf-8") as fp: + fp.write("#version: 0.2\n") + fp.write("\n".join(merges)) + + +KEYS_TO_MODIFY_MAPPING = { + "^encoder": "model.vqmodel.encoder", + "^decoder": "model.vqmodel.decoder", + "^post_quant_conv": "model.vqmodel.post_quant_conv", + "^quant_conv": "model.vqmodel.quant_conv", + "^quantize": "model.vqmodel.quantize", + "^model": "text_model.model", + r"lm_head\.weight": "text_model.lm_head.weight", + r"^text_model\.model\.vqmodel": "vqmodel", + # rename QKV proj for the VQ-VAE model because we use SiglipAttention + r"\.q\.": ".q_proj.", + r"\.k\.": ".k_proj.", + r"\.v\.": ".v_proj.", + r"\.proj_out\.": ".out_proj.", + # move the attention norms outside of attention modules + r"mid\.attn_1\.norm\.": "mid.attn_norm.", + r"attn\.0\.norm\.": "attn_norms.0.", + r"attn\.1\.norm\.": "attn_norms.1.", + r"attn\.2\.norm\.": "attn_norms.2.", + r"attn\.3\.norm\.": "attn_norms.3.", + # isolate down/mid/up into separate classes for readability + r"\.down\.": ".down_block.down.", + r"\.up\.": ".up_block.up.", + r"\.mid\.": ".middle_block.", +} + + +def convert_state_dict_to_hf(old_state_dict, new_state_dict): + for key, value in old_state_dict.items(): + # convert conv layers in attn to linear + if ( + any(key.endswith(name) for name in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]) + and value.ndim == 4 + ): + value = value.squeeze() + + for old_pattern, new_pattern in KEYS_TO_MODIFY_MAPPING.items(): + key = re.sub(old_pattern, new_pattern, key) + + new_state_dict[key] = value + return new_state_dict + + +def convert_model(vq_model_id, llm_model_id, output_dir, hub_model_id=None, test_inference=False): + os.makedirs(output_dir, exist_ok=True) + + # Convert and save processor + tokenizer_tiktoken = AutoTokenizer.from_pretrained(llm_model_id, trust_remote_code=True) + convert_tiktoken(tokenizer_tiktoken, output_dir) + extra_special_tokens = extra_special_tokens = { + "image_token": "", + "boi_token": "<|image start|>", + "eoi_token": "<|image end|>", + "image_wrapper_token": "<|image token|>", + "eof_token": "<|extra_201|>", + } + tokenizer_converted = AutoTokenizer.from_pretrained(output_dir, extra_special_tokens=extra_special_tokens) + tokenizer_converted.padding_side = "left" + + image_processor = Emu3ImageProcessor.from_pretrained(vq_model_id) + processor = Emu3Processor(image_processor, tokenizer_converted, chat_template=CHAT_TEMPLATE) + processor.save_pretrained(output_dir) + + # load models + model_llm = AutoModelForCausalLM.from_pretrained( + llm_model_id, + trust_remote_code=True, + ) + model_vqgan = AutoModel.from_pretrained(vq_model_id, trust_remote_code=True) + with open(f"{output_dir}/tokenizer.json", "r") as file: + tokenizer_config = json.load(file) + vocabulary_map = tokenizer_config["model"]["vocab"] + + text_config = Emu3TextConfig( + max_position_embeddings=model_llm.config.max_position_embeddings, + rope_scaling={"rope_type": "default"}, + ) + config = Emu3Config(text_config=text_config, vocabulary_map=vocabulary_map) + + with init_empty_weights(): + model = Emu3ForConditionalGeneration(config=config) + model.generation_config = GenerationConfig( + do_sample=True, + top_k=2048, + max_new_tokens=50_000, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + ) + + state_dict = {} + state_dict = convert_state_dict_to_hf(model_llm.state_dict(), state_dict) + state_dict = convert_state_dict_to_hf(model_vqgan.state_dict(), state_dict) + + model.load_state_dict(state_dict, assign=True, strict=True) + model.save_pretrained(output_dir, safe_serialization=True) + + if hub_model_id is not None: + model.push_to_hub(hub_model_id) + processor.push_to_hub(hub_model_id) + + if test_inference and llm_model_id.endswith("Chat"): + # Short inference on a few examples to check if generation makes sense + print("Loading the checkpoint in a Emu3 model...") + print("*" * 100) + model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + processor = Emu3Processor.from_pretrained(output_dir) + + conversation = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Please tell me about this art work and its artist."}, + {"type": "image"}, + ], + }, + ] + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + image = Image.open( + requests.get( + "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True + ).raw + ) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) + length = inputs.input_ids.shape[1] + + out = model.generate(**inputs, max_new_tokens=40, do_sample=False) + generated_text = processor.batch_decode(out[:, length:], skip_special_tokens=True)[0] + + print(f"Generation for single-image: {generated_text}") + print("*" * 100) + elif test_inference and llm_model_id.endswith("Gen"): + processor = Emu3Processor.from_pretrained(output_dir) + model = Emu3ForConditionalGeneration.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto") + + inputs = processor( + text=[ + "a portrait of young girl. masterpiece, film grained, best quality.", + "a dog running under the rain", + ], + padding=True, + return_tensors="pt", + return_for_image_generation=True, + ) + inputs = inputs.to(device="cuda:0", dtype=torch.bfloat16) + + neg_prompt = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry." + neg_inputs = processor(text=[neg_prompt] * 2, return_tensors="pt").to(device="cuda:0") + + image_sizes = inputs.pop("image_sizes") + HEIGHT, WIDTH = image_sizes[0] + VISUAL_TOKENS = model.vocabulary_mapping.image_tokens + + def prefix_allowed_tokens_fn(batch_id, input_ids): + height, width = HEIGHT, WIDTH + visual_tokens = VISUAL_TOKENS + image_token_id = processor.tokenizer.encode("<|image token|>", return_tensors="pt")[0].to(model.device) + eoi_token_id = processor.tokenizer.encode("<|image end|>", return_tensors="pt")[0] + eos_token_id = processor.tokenizer.encode("<|extra_204|>", return_tensors="pt")[0] + pad_token_id = processor.tokenizer.encode("<|endoftext|>", return_tensors="pt")[0] + eol_token_id = processor.tokenizer.encode("<|extra_200|>", return_tensors="pt")[0] + eof_token_id = processor.tokenizer.encode("<|extra_201|>", return_tensors="pt")[0] + + position = torch.nonzero(input_ids == image_token_id, as_tuple=True)[0][0] + offset = input_ids.shape[0] - position + if offset % (width + 1) == 0: + return (eol_token_id,) + elif offset == (width + 1) * height + 1: + return (eof_token_id,) + elif offset == (width + 1) * height + 2: + return (eoi_token_id,) + elif offset == (width + 1) * height + 3: + return (eos_token_id,) + elif offset > (width + 1) * height + 3: + return (pad_token_id,) + else: + return visual_tokens + + out = model.generate( + **inputs, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + negative_prompt_ids=neg_inputs.input_ids, + negative_prompt_attention_mask=neg_inputs.attention_mask, + ) + + image = model.decode_image_tokens(out[:, inputs.input_ids.shape[1] :], height=HEIGHT, width=WIDTH) + images = processor.postprocess( + list(image.float()), return_tensors="PIL.Image.Image" + ) # internally we convert to np but it's not supported in bf16 precision + for i, image in enumerate(images["pixel_values"]): + image.save(f"result_{i}.png") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--vq_model_id", + help="Model ID of Emu3 VQ-VAE on the hub", + default="BAAI/Emu3-VisionTokenizer", + ) + parser.add_argument( + "--llm_model_id", + help="Model ID of Emu3 bacbone LLM on the hub", + default="BAAI/Emu3-Chat", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model", + ) + parser.add_argument( + "--hub_model_id", + help="Model ID in the hub where to push the model.", + ) + parser.add_argument( + "--test_inference", + action="store_true", + help="Whether to load the model for generation to test it's converted correctly.", + ) + args = parser.parse_args() + convert_model( + vq_model_id=args.vq_model_id, + llm_model_id=args.llm_model_id, + output_dir=args.output_dir, + hub_model_id=args.hub_model_id, + test_inference=args.test_inference, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..a63269c99ef12e0303fe2c7a0bcf93f4918eb459 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/image_processing_emu3.py @@ -0,0 +1,552 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Iterable, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import convert_to_rgb, pad, resize, to_channel_dimension_format +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + VideoInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + is_valid_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +if is_vision_available(): + from PIL import Image + +logger = logging.get_logger(__name__) + + +def make_batched_images(images) -> List[List[ImageInput]]: + """ + Accepts images in list or nested list format, and makes a list of images for preprocessing. + + Args: + images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`): + The input image. + + Returns: + list: A list of images. + """ + if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]): + return [img for img_list in images for img in img_list] + + elif isinstance(images, (list, tuple)) and is_valid_image(images[0]): + return images + + elif is_valid_image(images): + return [images] + + raise ValueError(f"Could not make batched images from {images}") + + +def smart_resize( + height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + + """ + if height < factor or width < factor: + raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") + elif max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = math.floor(height / beta / factor) * factor + w_bar = math.floor(width / beta / factor) * factor + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + return h_bar, w_bar + + +class Emu3ImageProcessor(BaseImageProcessor): + r""" + Constructs a Emu3 image processor that dynamically resizes images based on the original images. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use when resizing the image. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats for each channel in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + min_pixels (`int`, *optional*, defaults to `512 * 512`): + The min pixels of the image to resize the image. + max_pixels (`int`, *optional*, defaults to `1024 * 1024`): + The max pixels of the image to resize the image. + spatial_factor (`int`, *optional*, defaults to 8): + The spatial downsample factor the image will be downsampled in feature extracting phase + """ + + model_input_names = ["pixel_values", "image_sizes"] + + def __init__( + self, + do_resize: bool = True, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_pad: bool = True, + min_pixels: int = 512 * 512, + max_pixels: int = 1024 * 1024, + spatial_factor: int = 8, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_resize = do_resize + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.min_pixels = min_pixels + self.max_pixels = max_pixels + self.spatial_factor = spatial_factor + self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels} + self.do_convert_rgb = do_convert_rgb + + def _preprocess( + self, + images: Union[ImageInput, VideoInput], + do_resize: Optional[bool] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`. + vision_info (`List[Dict]`, *optional*): + Optional list of dictionaries containing additional information about vision inputs. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Scale factor to use if rescaling the image. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + images = make_list_of_images(images) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + height, width = get_image_size(images[0], channel_dim=input_data_format) + resized_height, resized_width = height, width + processed_images = [] + for image in images: + if do_resize: + resized_height, resized_width = smart_resize( + height, + width, + factor=self.spatial_factor, + min_pixels=self.min_pixels, + max_pixels=self.max_pixels, + ) + image = resize( + image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + processed_images.append(image) + + images = np.array(processed_images) + return images + + def _pad_for_batching( + self, + pixel_values: List[np.ndarray], + image_sizes: List[List[int]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pads images on the `num_of_patches` dimension with zeros to form a batch of same number of patches. + + Args: + pixel_values (`List[np.ndarray]`): + An array of pixel values of each images of shape (`batch_size`, `num_patches`, `image_in_3D`) + image_sizes (`List[List[int]]`): + A list of sizes for each image in `pixel_values` in (height, width) format. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use same as the input image. + input_data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format for the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + If unset, will use the inferred format of the input image. + + Returns: + List[`np.ndarray`]: The padded images. + """ + + max_shape = ( + max([size[0] for size in image_sizes]), + max([size[1] for size in image_sizes]), + ) + pixel_values = [ + pad( + image, + padding=((0, max_shape[0] - size[0]), (0, max_shape[1] - size[1])), + data_format=data_format, + input_data_format=input_data_format, + ) + for image, size in zip(pixel_values, image_sizes) + ] + return pixel_values + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + do_pad: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_pad = do_pad if do_pad is not None else self.do_pad + + if images is not None: + images = make_batched_images(images) + + if images is not None and not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + pixel_values = [] + for image in images: + image = self._preprocess( + image, + do_resize=do_resize, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + do_convert_rgb=do_convert_rgb, + input_data_format=input_data_format, + ) + pixel_values.extend(image) + + image_sizes = [image.shape[-2:] for image in pixel_values] + if do_pad: + pixel_values = self._pad_for_batching(pixel_values, image_sizes) + pixel_values = np.array(pixel_values) + + return BatchFeature( + data={"pixel_values": pixel_values, "image_sizes": image_sizes}, tensor_type=return_tensors + ) + + def postprocess( + self, + images: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Union[str, TensorType] = "PIL.Image.Image", + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess. + The parameters should be same as in preprocess. + Args: + images (`ImageInput`): + Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = 1.0 / self.rescale_factor if rescale_factor is None else rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + if isinstance(images[0], Image.Image): + return images if len(images) > 1 else images[0] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + pixel_values = [] + for image in images: + image = to_numpy_array(image) + if do_normalize: + image = self.unnormalize( + image=image, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format + ) + + if do_rescale: + image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format) + image = image.clip(0, 255).astype(np.uint8) + + if do_normalize and do_rescale and return_tensors == "PIL.Image.Image": + image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format) + pixel_values.append(Image.fromarray(image)) + else: + pixel_values.extend(image) + + data = {"pixel_values": pixel_values} + return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None + + return BatchFeature(data=data, tensor_type=return_tensors) + + def unnormalize( + self, + image: np.array, + image_mean: Union[float, Iterable[float]], + image_std: Union[float, Iterable[float]], + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.array: + """ + Unnormalizes `image` using the mean and standard deviation specified by `mean` and `std`. + image = (image * image_std) + image_mean + Args: + image (`torch.Tensor` of shape `(batch_size, num_channels, image_size, image_size)` or `(num_channels, image_size, image_size)`): + Batch of pixel values to postprocess. + image_mean (`float` or `Iterable[float]`): + The mean to use for unnormalization. + image_std (`float` or `Iterable[float]`): + The standard deviation to use for unnormalization. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + num_channels = 3 + + if isinstance(image_mean, Iterable): + if len(image_mean) != num_channels: + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(image_mean)}") + else: + image_mean = [image_mean] * num_channels + + if isinstance(image_std, Iterable): + if len(image_std) != num_channels: + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(image_std)}") + else: + image_std = [image_std] * num_channels + + rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std)) + rev_image_std = tuple(1 / std for std in image_std) + image = self.normalize( + image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format + ) + return image + + +__all__ = ["Emu3ImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..375b1beb230ce89ef388f0f7b521bb1d28acec00 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/modeling_emu3.py @@ -0,0 +1,1976 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/emu3/modular_emu3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_emu3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + is_torch_flex_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "Emu3Config" + + +@use_kernel_forward_from_hub("RMSNorm") +class Emu3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Emu3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Emu3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Emu3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Emu3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Emu3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Emu3MLP(config) + self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, hidden_states): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Tuple[int], + stride: Tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +EMU3_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3VQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + EMU3_VQ_START_DOCSTRING, +) +class Emu3VQVAE(PreTrainedModel): + config_class = Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_flex_attn = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) + + video = self.decoder(post_quant, quant) + video = video.reshape( + batch_size, + temporal * self.config.temporal_downsample_factor, + self.config.out_channels, + height * self.spatial_scale_factor, + width * self.spatial_scale_factor, + ) + return video[:, 0] if is_image else video + + +class Emu3ImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.eol_token_id = vocab_map.get("<|extra_200|>") + self.image_token_id = vocab_map.get("") + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def image_tokens_str(self): + return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def img2bpe(self): + return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str} + + @cached_property + def bpe2img(self): + return {v: k for k, v in self.img2bpe.items()} + + @cached_property + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: List[torch.Tensor]) -> torch.Tensor: + device = img_batch.device + eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + img_tokens = torch.cat([img_tokens, eol_row], dim=-1) + return img_tokens.to(device) + + def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_batch = img_batch[..., :-1] # remove last row of EOL tokens + img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +EMU3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare emu3 Model outputting raw hidden-states without any specific head on top.", + EMU3_START_DOCSTRING, +) +class Emu3PreTrainedModel(PreTrainedModel): + config_class = Emu3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [ + "Emu3DecoderLayer", + ] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_cache_class = True + _supports_static_cache = True + _supports_param_buffer_assignment = False + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Emu3RMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) + + +class Emu3RotaryEmbedding(nn.Module): + def __init__(self, config: Emu3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +EMU3_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Emu3Text Model outputting raw hidden-states without any specific head on top.", + EMU3_START_DOCSTRING, +) +class Emu3TextModel(Emu3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Emu3TextDecoderLayer`] + + Args: + config: Emu3TextConfig + """ + + def __init__(self, config: Emu3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Emu3RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache + if not isinstance(past_key_values, (type(None), Cache)): + raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + config_class = Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +EMU3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["text_model.lm_head.weight"] + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable + + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"] diff --git a/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e35e71d21baacd46310d029b9caf7d5925457a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/modular_emu3.py @@ -0,0 +1,1321 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from functools import cached_property +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_outputs import ( + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + can_return_tuple, + logging, + replace_return_docstrings, +) +from ..chameleon.modeling_chameleon import ( + ChameleonPreTrainedModel, + ChameleonVQVAEEncoderConvDownsample, +) +from ..llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, +) +from ..siglip.modeling_siglip import SiglipAttention +from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig + + +_CONFIG_FOR_DOC = "Emu3Config" +_CHECKPOINT_FOR_DOC = "BAAI/Emu3-Chat-hf" + +logger = logging.get_logger(__name__) + + +# Has extra dropout which no other model in the library has +class Emu3DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Emu3Config, layer_idx: int): + super().__init__(config, layer_idx) + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + self.dropout(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class Emu3VQVAEVectorQuantizer(nn.Module): + """ + A module for vector quantization using learned embedding vectors. + + This module implements the quantization process similar to te one described in + the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous + input vectors into discrete codebook vectors, which are learned during training. + Current implementation improves over previous ones by avoiding costly matrix multiplications + and allowing for post-hoc remapping of indices. + """ + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + self.embedding = nn.Embedding(config.codebook_size, config.embed_dim) + self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size) + + def forward(self, hidden_state: torch.Tensor): + batch_size, temporal, channels, height, width = hidden_state.shape + hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous() + hidden_state_flattened = hidden_state.view(-1, channels) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + embedding_sum = torch.sum(self.embedding.weight**2, dim=1) + + # "bd,dn->bn", + distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1)) + distances = hidden_state_sum + embedding_sum - distances + + min_encoding_indices = torch.argmin(distances, dim=1) + min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width) + return min_encoding_indices + + +class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample): + pass + + +class Emu3VQVAEEncoderConvUpsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, hidden_states): + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAEConv3d(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Tuple[int], + stride: Tuple[int], + ): + super().__init__() + + padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])] + self.padding = () + for pad_size in padding_sizes[::-1]: + self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2) + self.padding += (2, 0) + + self.conv = nn.Conv3d( + in_channel, + out_channel, + kernel_size, + stride=stride, + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = F.pad(hidden_states, self.padding) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAESpatialNorm(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + ): + super().__init__() + self.norm_layer = nn.GroupNorm( + num_channels=out_channels, + num_groups=32, + eps=1e-6, + affine=True, + ) + + self.conv_y = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.conv_b = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest") + hidden_states = self.norm_layer(hidden_states) + hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states) + return hidden_states + + +class Emu3VQVAETemporalUpsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + batch_size, channels, temporal, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal) + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous() + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalDownsample(nn.Module): + def __init__( + self, + in_channel: int, + out_channel: int, + ): + super().__init__() + self.conv = Emu3VQVAEConv3d( + in_channel, + out_channel, + kernel_size=(4, 3, 3), + stride=(2, 1, 1), + ) + + def forward(self, hidden_states: torch.Tensor): + hidden_states = self.conv(hidden_states) + return hidden_states + + +class Emu3VQVAETemporalResnetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = nn.BatchNorm3d(in_channels) + self.conv1 = Emu3VQVAEConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + self.norm2 = nn.BatchNorm3d(out_channels) + self.conv2 = Emu3VQVAEConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + quant_channels: Optional[int] = None, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.quant_channels = quant_channels + + if quant_channels is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels) + self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels) + + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None): + norm_args = () if self.quant_channels is None else (quant_channels,) + + residual = hidden_states + hidden_states = self.norm1(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, *norm_args) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +class Emu3VQVAEAttentionBlock(SiglipAttention): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + # for compatibility with the attention interface + self.num_key_value_groups = 1 + + +class Emu3VQVAEGroupNorm(nn.GroupNorm): + """ + Same as the torch GroupNorm with the only difference that this ones accepts + an optional kwarg `quant_states` which is not used. This class makes it easier to + use SpatialNorm or GroupNorm without conditionals + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def forward(self, input, quant_states=None): + return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) + + +class Emu3VQVAEMiddleBlock(nn.Module): + def __init__(self, config, in_channels, quant_channels=None): + super().__init__() + + self.block_1 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + self.attn_1 = Emu3VQVAEAttentionBlock(config) + if quant_channels is None: + self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True) + else: + self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels) + + self.block_2 = Emu3VQVAEResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + quant_channels=quant_channels, + ) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None): + hidden_states = self.block_1(hidden_states, quant_states) + residual = hidden_states + hidden_states = self.attn_norm(hidden_states, quant_states) + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = self.attn_1(hidden_states)[0] + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + hidden_states = self.block_2(hidden_states, quant_states) + return hidden_states + + +class Emu3VQVAEDownBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + channel_multiplier = config.channel_multiplier + + in_channel_multiplier = (1,) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + ) + ) + block_in = block_out + if config.attn_resolutions is not None and i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)) + + down = nn.Module() + down.block = block + down.attn = attn + down.attn_norms = attn_norms + if i_level != self.num_resolutions - 1: + down.downsample = Emu3VQVAEEncoderConvDownsample(block_in) + self.down.append(down) + + def forward(self, hidden_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.down): + for i_block in range(self.num_res_blocks): + hidden_states = blocks.block[i_block](hidden_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + + if i_level != self.num_resolutions - 1: + hidden_states = blocks.downsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEUpBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + attn_norms = nn.ModuleList() + block_out = config.base_channels * config.channel_multiplier[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Emu3VQVAEResnetBlock( + in_channels=block_in, + out_channels=block_out, + quant_channels=quant_channels, + ) + ) + block_in = block_out + if i_level in config.attn_resolutions: + attn.append(Emu3VQVAEAttentionBlock(config)) + attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in)) + + up = nn.Module() + up.block = block + up.attn = attn + up.attn_norms = attn_norms + if i_level != 0: + up.upsample = Emu3VQVAEEncoderConvUpsample(block_in) + + self.up.insert(0, up) + + def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor): + for i_level, blocks in enumerate(self.up[::-1]): + for i_block in range(self.num_res_blocks + 1): + hidden_states = blocks.block[i_block](hidden_states, quant_states) + if len(blocks.attn) > 0: + residual = hidden_states + hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2) + hidden_states = blocks.attn[i_block](hidden_states)[0] + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + hidden_states = residual + hidden_states + if i_level != len(self.up) - 1: + hidden_states = blocks.upsample(hidden_states) + + return hidden_states + + +class Emu3VQVAEEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + base_channels = config.base_channels + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + out_channels = 2 * latent_channels if double_latent else latent_channels + block_in = base_channels * channel_multiplier[-1] + + self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1) + self.down_block = Emu3VQVAEDownBlock(config) + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + temporal_down_blocks = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + self.time_res_stack = nn.ModuleList() + + for i in range(temporal_down_blocks): + conv = Emu3VQVAETemporalDownsample(out_channels, out_channels) + self.time_conv.append(conv) + + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=out_channels, + out_channels=out_channels, + ) + self.time_res_stack.append(time_res_conv) + + def forward(self, pixel_values: torch.LongTensor): + temporal_dim = pixel_values.shape[1] + pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:]) + + # downsampling & middle + hidden_states = self.conv_in(pixel_values) + hidden_states = self.down_block(hidden_states) + hidden_states = self.middle_block(hidden_states) + + # end + hidden_states = self.norm_out(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:]) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for conv in self.time_conv: + hidden_states = conv(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + + for layer in self.time_res_stack: + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + return hidden_states + + +class Emu3VQVAEDecoder(nn.Module): + def __init__(self, config: Emu3VQVAEConfig): + super().__init__() + + quant_channels = config.embed_dim + block_in = config.base_channels * config.channel_multiplier[-1] + self.time_res_stack = nn.ModuleList() + for _ in range(config.num_res_blocks): + time_res_conv = Emu3VQVAETemporalResnetBlock( + in_channels=config.latent_channels, out_channels=config.latent_channels + ) + self.time_res_stack.append(time_res_conv) + + temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor)) + self.time_conv = nn.ModuleList() + for i in range(temp_upsample_block_num): + conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels) + self.time_conv.append(conv) + + self.conv_in = nn.Conv2d( + config.latent_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + ) + + self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels) + self.up_block = Emu3VQVAEUpBlock(config) + + block_in = config.base_channels * config.channel_multiplier[0] + self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in) + self.conv_out = nn.Conv2d( + block_in, + config.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor): + hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0) + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + + # temporal convs + for layer in self.time_res_stack: + hidden_quant_states = layer(hidden_quant_states) + + for layer in self.time_conv: + hidden_quant_states = layer(hidden_quant_states) + hidden_quant_states *= torch.sigmoid(hidden_quant_states) + + hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4) + hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0) + hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:]) + quant_states = quant_states.reshape(-1, *quant_states.shape[2:]) + + hidden_states = self.conv_in(hidden_states) + + # middle & upsampling + hidden_states = self.middle_block(hidden_states, quant_states) + hidden_states = self.up_block(hidden_states, quant_states) + + hidden_states = self.norm_out(hidden_states, quant_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +EMU3_VQ_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Emu3VQVAEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + """The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens. + This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from + [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131). + """, + EMU3_VQ_START_DOCSTRING, +) +class Emu3VQVAE(PreTrainedModel): + config_class = Emu3VQVAEConfig + base_model_prefix = "emuvideovq" + main_input_name = "pixel_values" + _supports_sdpa = True + _supports_flash_attn_2 = True + _supports_flex_attn = True + _no_split_modules = [ + "Emu3VQVAETemporalResnetBlock", + "Emu3VQVAEAttentionBlock", + "Emu3VQVAEResnetBlock", + "Emu3VQVAEVectorQuantizer", + ] + + def _init_weights(self, module): + if isinstance(module, (nn.Conv2d, nn.Conv3d)): + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, nn.Linear): + nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(module.bias, -bound, bound) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)): + nn.init.constant_(module.weight, 1.0) + nn.init.constant_(module.bias, 0.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_() + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def __init__(self, config: Emu3VQVAEConfig): + super().__init__(config) + + self.config = config + + self.encoder = Emu3VQVAEEncoder(config) + self.decoder = Emu3VQVAEDecoder(config) + self.quantize = Emu3VQVAEVectorQuantizer(config) + self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1) + + self.quant_conv = Emu3VQVAEConv3d( + config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.post_quant_conv = Emu3VQVAEConv3d( + config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1) + ) + self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1) + self.eval() # Emu3's VQ model is frozen + + self.post_init() + + def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor): + is_image = pixel_values.ndim == 4 + if is_image: + temporal = self.config.temporal_downsample_factor + batch_size, channels, height, width = pixel_values.shape + pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1) + else: + batch_size, temporal, channels, height, width = pixel_values.shape + + hidden_states = self.encoder(pixel_values) + + # b t c h w -> b c t h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + hidden_states = self.quant_conv(hidden_states) + + # b c t h w -> b t c h w + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + codes = self.quantize(hidden_states) + + image_tokens = codes.squeeze(1) if is_image else codes + + image_tokens = [ + single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)] + for single_image, size in zip(image_tokens, image_sizes) + ] + + return image_tokens + + def decode(self, hidden_states: torch.Tensor): + is_image = hidden_states.ndim == 3 + if is_image: + hidden_states = hidden_states.unsqueeze(1) + + batch_size, temporal, height, width = hidden_states.shape + quant = self.quantize.embedding(hidden_states.flatten()) + + channels = quant.shape[-1] + quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous() + post_quant = self.post_quant_conv(quant) + + quant = quant.permute(0, 2, 1, 3, 4) + post_quant = post_quant.permute(0, 2, 1, 3, 4) + + video = self.decoder(post_quant, quant) + video = video.reshape( + batch_size, + temporal * self.config.temporal_downsample_factor, + self.config.out_channels, + height * self.spatial_scale_factor, + width * self.spatial_scale_factor, + ) + return video[:, 0] if is_image else video + + +class Emu3ImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.eol_token_id = vocab_map.get("<|extra_200|>") + self.image_token_id = vocab_map.get("") + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def image_tokens_str(self): + return sorted([name for name, val in self.vocab_map.items() if name.startswith("<|visual token")]) + + @cached_property + def img2bpe(self): + return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str} + + @cached_property + def bpe2img(self): + return {v: k for k, v in self.img2bpe.items()} + + @cached_property + def bpe2img_mapping_tensor(self): + mapping = torch.zeros(max(self.bpe2img.keys()) + 1, dtype=torch.int) + for k, v in self.bpe2img.items(): + mapping[k] = v + return mapping + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: List[torch.Tensor]) -> torch.Tensor: + device = img_batch.device + eol_row = torch.ones((img_batch.shape[0], 1), dtype=torch.int) * self.eol_token_id + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + img_tokens = torch.cat([img_tokens, eol_row], dim=-1) + return img_tokens.to(device) + + def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_batch = img_batch[..., :-1] # remove last row of EOL tokens + img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE): + _no_split_modules = [ + "Emu3DecoderLayer", + ] + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Emu3RMSNorm): # noqa: F821 + module.weight.data.fill_(1.0) + + +EMU3_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache type that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +EMU3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using + [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([]`Emu3Processor`] uses + [`Emu3ImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Has to be an instance of [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class Emu3TextModel(LlamaModel, Emu3PreTrainedModel): + def __init__(self, config: Emu3Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + def forward(self, **super_kwargs): + super().forward(**super_kwargs) + + +class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin): + config_class = Emu3TextConfig + + def __init__(self, config): + super().__init__(config) + self.model = Emu3TextModel(config) + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class="Emu3TextConfig") + def forward(**super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + super().forward() + + +class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["text_model.lm_head.weight"] + _supports_static_cache = False # `get_image_tokens()`, called when `pixel_values` is passed, is not compileable + + def __init__(self, config): + super().__init__(config) + self.text_model = Emu3ForCausalLM._from_config(config.text_config) + self.vqmodel = Emu3VQVAE(config.vq_config) + self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.text_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.text_model.set_input_embeddings(value) + + def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor): + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`): + The sizes of the images in the batch, being (height, width) for each image. + """ + image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes) + bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list] + bpe_tokens = torch.cat(bpe_tokens_list) + return bpe_tokens + + @torch.no_grad + def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int): + """ + Decodes generated image tokens from language model to continuous pixel values + with VQGAN module via upsampling. + + Args: + image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`): + The tensors corresponding to the input images. + height (`int`): + Height of the generated image before upsampling. + width (`int`): + Width of the generated image before upsampling. + """ + sequences = image_tokens[:, :-3].view(-1, height, width + 1) + image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences) + image = self.vqmodel.decode(image_tokens) + return image + + @can_return_tuple + @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration + >>> import torch + >>> import requests + >>> from PIL import Image + + >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", torch_dtype=torch.bfloat16) + >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf") + + >>> conversation = [ + ... { + ... "role": "system", + ... "content": [ + ... {"type": "text", "text": "You are a helpful assistant."}, + ... ], + ... }, + ... { + ... "role": "user", + ... "content": [ + ... {"type": "image"}, + ... {"type": "text", "text": "Please describe the image."}, + ... ], + ... }, + ... ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw) + + >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16) + + >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None: + image_tokens = self.get_image_tokens(pixel_values, image_sizes) + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + use_cache=use_cache, + **kwargs, + ) + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + + return model_inputs + + +__all__ = [ + "Emu3ForConditionalGeneration", + "Emu3ForCausalLM", + "Emu3TextModel", + "Emu3PreTrainedModel", + "Emu3VQVAE", +] diff --git a/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py b/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py new file mode 100644 index 0000000000000000000000000000000000000000..a94dc08cd97da2cd0669d263bd423446479d69cf --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/emu3/processing_emu3.py @@ -0,0 +1,222 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class Emu3TextKwargs(TextKwargs, total=False): + return_for_image_generation: bool + + +class Emu3ImagesKwargs(ImagesKwargs, total=False): + ratio: str + image_area: int + + +class Emu3ProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: Emu3TextKwargs + images_kwargs: Emu3ImagesKwargs + _defaults = { + "text_kwargs": { + "return_for_image_generation": False, + }, + "images_kwargs": { + "ratio": "1:1", + "image_area": 518400, + }, + } + + +class Emu3Processor(ProcessorMixin): + r""" + Constructs a Emu3 processor which wraps a Emu3 image processor and a GPT2 tokenizer into a single + processor. + + [`Emu3Processor`] offers all the functionalities of [`Emu3ImageProcessor`] and [`GPT2TokenizerFast`]. + See the [`~Emu3Processor.__call__`] and [`~Emu3Processor.decode`] for more information. + + Args: + image_processor ([`Emu3ImageProcessor`]): + The image processor is a required input. + tokenizer ([`Emu3TokenizerFast`]): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast") + image_processor_class = "Emu3ImageProcessor" + + def __init__( + self, + image_processor, + tokenizer, + chat_template=None, + **kwargs, + ): + self.image_token = tokenizer.image_token # image_token as placeholder to be replaced by vq-vae tokens + self.image_token_id = tokenizer.image_token_id + self.image_start_token = tokenizer.boi_token # "<|image start|>" fixed tokens for start and end of image + self.image_end_token = tokenizer.eoi_token # "<|image end|>" + self.fake_token_around_image = tokenizer.image_wrapper_token # "<|image token|>" every image starts with it + self.eof_token = tokenizer.eof_token # "<|extra_201|>" + self.bos_token = tokenizer.bos_token + self.downsample_ratio = 8 + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[Emu3ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Emu3TokenizerFast's [`~Emu3TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring + of the above two methods for more information. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + # check if images and text inputs are reversed for BC + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise TypeError("Invalid input text. Please provide a string, or a list of strings") + + output_kwargs = self._merge_kwargs( + Emu3ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_image_generation = output_kwargs["text_kwargs"].pop("return_for_image_generation", False) + ratio = output_kwargs["images_kwargs"].pop("ratio", None) + image_area = output_kwargs["images_kwargs"].pop("image_area", None) + + if return_for_image_generation and images is not None: + raise ValueError("You should not provide `images` when `return_for_image_generation=True`") + + if not return_for_image_generation and text is None and images is None: + raise ValueError("You must provide either text or images when `return_for_image_generation=False`") + + image_features = {} + image_start_tokens = f"{self.image_start_token}" + image_end_tokens = f"{self.eof_token}{self.image_end_token}" + + # generate text from image + text input, so we add placeholders for image tokens + if not return_for_image_generation and images is not None: + image_features = self.image_processor(images, **output_kwargs["images_kwargs"]) + image_sizes = iter(image_features.image_sizes) + + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + height, width = image_size + height = height // self.downsample_ratio + width = width // self.downsample_ratio + image_seq_length = height * (width + 1) # +1 for extra row when converting to BPE in modeling code + + image_placeholder = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}{'' * image_seq_length}{image_end_tokens}" + sample = sample.replace(self.image_token, image_placeholder, 1) + sample = f"{self.bos_token}{sample}" # add BOS because PT tokenizer doesn't add it + prompt_strings.append(sample) + text = [sample.replace("", self.image_token) for sample in prompt_strings] + + # generate image from text input, so we add begin-of-image tokens from where image generation starts + elif return_for_image_generation: + height, width = self.calculate_generate_size(ratio, image_area, self.downsample_ratio) + image_prompt = f"{image_start_tokens}{height}*{width}{self.fake_token_around_image}" + text = [f"{self.bos_token}{sample}{image_prompt}" for sample in text] + image_features["image_sizes"] = [[height, width]] * len(text) + + # else just generate from text-only input, and we do no special treatment for text + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + data = self.tokenizer(text, **output_kwargs["text_kwargs"]) + self._check_special_mm_tokens(text, data, modalities=["image"]) + + data.update(**image_features) + + return BatchFeature(data=data, tensor_type=return_tensors) + + def calculate_generate_size(self, ratio, image_area, spatial_factor): + width, height = map(int, ratio.split(":")) + current_area = width * height + target_ratio = (image_area / current_area) ** 0.5 + + token_height = int(round(height * target_ratio / spatial_factor)) + token_width = int(round(width * target_ratio / spatial_factor)) + return token_height, token_width + + def postprocess(self, images: ImageInput, **kwargs): + return self.image_processor.postprocess(images, **kwargs) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Emu3TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + +__all__ = ["Emu3Processor"] diff --git a/docs/transformers/build/lib/transformers/models/encodec/__init__.py b/docs/transformers/build/lib/transformers/models/encodec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3adeea056604d1d31f946a5cd0bf53ea590ea3aa --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encodec/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_encodec import * + from .feature_extraction_encodec import * + from .modeling_encodec import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..77fd67727dc390e0b6a402f4b3617e6aae9e9791 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encodec/configuration_encodec.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EnCodec model configuration""" + +import math +from typing import Optional + +import numpy as np + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class EncodecConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`EncodecModel`]. It is used to instantiate a + Encodec model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + target_bandwidths (`List[float]`, *optional*, defaults to `[1.5, 3.0, 6.0, 12.0, 24.0]`): + The range of diffent bandwiths the model can encode audio with. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + audio_channels (`int`, *optional*, defaults to 1): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + normalize (`bool`, *optional*, defaults to `False`): + Whether the audio shall be normalized when passed. + chunk_length_s (`float`, *optional*): + If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. + overlap (`float`, *optional*): + Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following + formulae : `int((1.0 - self.overlap) * self.chunk_length)`. + hidden_size (`int`, *optional*, defaults to 128): + Intermediate representation dimension. + num_filters (`int`, *optional*, defaults to 32): + Number of convolution kernels of first `EncodecConv1d` down sampling layer. + num_residual_layers (`int`, *optional*, defaults to 1): + Number of residual layers. + upsampling_ratios (`Sequence[int]` , *optional*, defaults to `[8, 5, 4, 2]`): + Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it + will use the ratios in the reverse order to the ones specified here that must match the decoder order. + norm_type (`str`, *optional*, defaults to `"weight_norm"`): + Normalization method. Should be in `["weight_norm", "time_group_norm"]` + kernel_size (`int`, *optional*, defaults to 7): + Kernel size for the initial convolution. + last_kernel_size (`int`, *optional*, defaults to 7): + Kernel size for the last convolution layer. + residual_kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the residual layers. + dilation_growth_rate (`int`, *optional*, defaults to 2): + How much to increase the dilation with each layer. + use_causal_conv (`bool`, *optional*, defaults to `True`): + Whether to use fully causal convolution. + pad_mode (`str`, *optional*, defaults to `"reflect"`): + Padding mode for the convolutions. + compress (`int`, *optional*, defaults to 2): + Reduced dimensionality in residual branches (from Demucs v3). + num_lstm_layers (`int`, *optional*, defaults to 2): + Number of LSTM layers at the end of the encoder. + trim_right_ratio (`float`, *optional*, defaults to 1.0): + Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If + equal to 1.0, it means that all the trimming is done at the right. + codebook_size (`int`, *optional*, defaults to 1024): + Number of discret codes that make up VQVAE. + codebook_dim (`int`, *optional*): + Dimension of the codebook vectors. If not defined, uses `hidden_size`. + use_conv_shortcut (`bool`, *optional*, defaults to `True`): + Whether to use a convolutional layer as the 'skip' connection in the `EncodecResnetBlock` block. If False, + an identity function will be used, giving a generic residual connection. + + Example: + + ```python + >>> from transformers import EncodecModel, EncodecConfig + + >>> # Initializing a "facebook/encodec_24khz" style configuration + >>> configuration = EncodecConfig() + + >>> # Initializing a model (with random weights) from the "facebook/encodec_24khz" style configuration + >>> model = EncodecModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "encodec" + + def __init__( + self, + target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0], + sampling_rate=24_000, + audio_channels=1, + normalize=False, + chunk_length_s=None, + overlap=None, + hidden_size=128, + num_filters=32, + num_residual_layers=1, + upsampling_ratios=[8, 5, 4, 2], + norm_type="weight_norm", + kernel_size=7, + last_kernel_size=7, + residual_kernel_size=3, + dilation_growth_rate=2, + use_causal_conv=True, + pad_mode="reflect", + compress=2, + num_lstm_layers=2, + trim_right_ratio=1.0, + codebook_size=1024, + codebook_dim=None, + use_conv_shortcut=True, + **kwargs, + ): + self.target_bandwidths = target_bandwidths + self.sampling_rate = sampling_rate + self.audio_channels = audio_channels + self.normalize = normalize + self.chunk_length_s = chunk_length_s + self.overlap = overlap + self.hidden_size = hidden_size + self.num_filters = num_filters + self.num_residual_layers = num_residual_layers + self.upsampling_ratios = upsampling_ratios + self.norm_type = norm_type + self.kernel_size = kernel_size + self.last_kernel_size = last_kernel_size + self.residual_kernel_size = residual_kernel_size + self.dilation_growth_rate = dilation_growth_rate + self.use_causal_conv = use_causal_conv + self.pad_mode = pad_mode + self.compress = compress + self.num_lstm_layers = num_lstm_layers + self.trim_right_ratio = trim_right_ratio + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size + self.use_conv_shortcut = use_conv_shortcut + + if self.norm_type not in ["weight_norm", "time_group_norm"]: + raise ValueError( + f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}' + ) + + super().__init__(**kwargs) + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_length(self) -> Optional[int]: + if self.chunk_length_s is None: + return None + else: + return int(self.chunk_length_s * self.sampling_rate) + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_stride(self) -> Optional[int]: + if self.chunk_length_s is None or self.overlap is None: + return None + else: + return max(1, int((1.0 - self.overlap) * self.chunk_length)) + + @property + def frame_rate(self) -> int: + hop_length = np.prod(self.upsampling_ratios) + return math.ceil(self.sampling_rate / hop_length) + + @property + def num_quantizers(self) -> int: + return int(1000 * self.target_bandwidths[-1] // (self.frame_rate * 10)) + + +__all__ = ["EncodecConfig"] diff --git a/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fb0168705f42b943cb5d9bd40aa904c294cfb6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py @@ -0,0 +1,365 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert EnCodec checkpoints.""" + +import argparse + +import torch + +from transformers import ( + EncodecConfig, + EncodecFeatureExtractor, + EncodecModel, + logging, +) + + +# checkpoints downloaded from: +# https://dl.fbaipublicfiles.com/encodec/v0/encodec_24khz-d7cc33bc.th +# https://huggingface.co/facebook/musicgen-small/resolve/main/compression_state_dict.bin +# https://dl.fbaipublicfiles.com/encodec/v0/encodec_48khz-7e698e3e.th + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.encodec") + +MAPPING_QUANTIZER = { + "quantizer.vq.layers.*._codebook.inited": "quantizer.layers.*.codebook.inited", + "quantizer.vq.layers.*._codebook.cluster_size": "quantizer.layers.*.codebook.cluster_size", + "quantizer.vq.layers.*._codebook.embed": "quantizer.layers.*.codebook.embed", + "quantizer.vq.layers.*._codebook.embed_avg": "quantizer.layers.*.codebook.embed_avg", +} +MAPPING_ENCODER = { + "encoder.model.0.conv.conv": "encoder.layers.0.conv", + "encoder.model.1.block.1.conv.conv": "encoder.layers.1.block.1.conv", + "encoder.model.1.block.3.conv.conv": "encoder.layers.1.block.3.conv", + "encoder.model.1.shortcut.conv.conv": "encoder.layers.1.shortcut.conv", + "encoder.model.3.conv.conv": "encoder.layers.3.conv", + "encoder.model.4.block.1.conv.conv": "encoder.layers.4.block.1.conv", + "encoder.model.4.block.3.conv.conv": "encoder.layers.4.block.3.conv", + "encoder.model.4.shortcut.conv.conv": "encoder.layers.4.shortcut.conv", + "encoder.model.6.conv.conv": "encoder.layers.6.conv", + "encoder.model.7.block.1.conv.conv": "encoder.layers.7.block.1.conv", + "encoder.model.7.block.3.conv.conv": "encoder.layers.7.block.3.conv", + "encoder.model.7.shortcut.conv.conv": "encoder.layers.7.shortcut.conv", + "encoder.model.9.conv.conv": "encoder.layers.9.conv", + "encoder.model.10.block.1.conv.conv": "encoder.layers.10.block.1.conv", + "encoder.model.10.block.3.conv.conv": "encoder.layers.10.block.3.conv", + "encoder.model.10.shortcut.conv.conv": "encoder.layers.10.shortcut.conv", + "encoder.model.12.conv.conv": "encoder.layers.12.conv", + "encoder.model.13.lstm": "encoder.layers.13.lstm", + "encoder.model.15.conv.conv": "encoder.layers.15.conv", +} +MAPPING_ENCODER_48K = { + "encoder.model.0.conv.norm": "encoder.layers.0.norm", + "encoder.model.1.block.1.conv.norm": "encoder.layers.1.block.1.norm", + "encoder.model.1.block.3.conv.norm": "encoder.layers.1.block.3.norm", + "encoder.model.1.shortcut.conv.norm": "encoder.layers.1.shortcut.norm", + "encoder.model.3.conv.norm": "encoder.layers.3.norm", + "encoder.model.4.block.1.conv.norm": "encoder.layers.4.block.1.norm", + "encoder.model.4.block.3.conv.norm": "encoder.layers.4.block.3.norm", + "encoder.model.4.shortcut.conv.norm": "encoder.layers.4.shortcut.norm", + "encoder.model.6.conv.norm": "encoder.layers.6.norm", + "encoder.model.7.block.1.conv.norm": "encoder.layers.7.block.1.norm", + "encoder.model.7.block.3.conv.norm": "encoder.layers.7.block.3.norm", + "encoder.model.7.shortcut.conv.norm": "encoder.layers.7.shortcut.norm", + "encoder.model.9.conv.norm": "encoder.layers.9.norm", + "encoder.model.10.block.1.conv.norm": "encoder.layers.10.block.1.norm", + "encoder.model.10.block.3.conv.norm": "encoder.layers.10.block.3.norm", + "encoder.model.10.shortcut.conv.norm": "encoder.layers.10.shortcut.norm", + "encoder.model.12.conv.norm": "encoder.layers.12.norm", + "encoder.model.15.conv.norm": "encoder.layers.15.norm", +} +MAPPING_DECODER = { + "decoder.model.0.conv.conv": "decoder.layers.0.conv", + "decoder.model.1.lstm": "decoder.layers.1.lstm", + "decoder.model.3.convtr.convtr": "decoder.layers.3.conv", + "decoder.model.4.block.1.conv.conv": "decoder.layers.4.block.1.conv", + "decoder.model.4.block.3.conv.conv": "decoder.layers.4.block.3.conv", + "decoder.model.4.shortcut.conv.conv": "decoder.layers.4.shortcut.conv", + "decoder.model.6.convtr.convtr": "decoder.layers.6.conv", + "decoder.model.7.block.1.conv.conv": "decoder.layers.7.block.1.conv", + "decoder.model.7.block.3.conv.conv": "decoder.layers.7.block.3.conv", + "decoder.model.7.shortcut.conv.conv": "decoder.layers.7.shortcut.conv", + "decoder.model.9.convtr.convtr": "decoder.layers.9.conv", + "decoder.model.10.block.1.conv.conv": "decoder.layers.10.block.1.conv", + "decoder.model.10.block.3.conv.conv": "decoder.layers.10.block.3.conv", + "decoder.model.10.shortcut.conv.conv": "decoder.layers.10.shortcut.conv", + "decoder.model.12.convtr.convtr": "decoder.layers.12.conv", + "decoder.model.13.block.1.conv.conv": "decoder.layers.13.block.1.conv", + "decoder.model.13.block.3.conv.conv": "decoder.layers.13.block.3.conv", + "decoder.model.13.shortcut.conv.conv": "decoder.layers.13.shortcut.conv", + "decoder.model.15.conv.conv": "decoder.layers.15.conv", +} +MAPPING_DECODER_48K = { + "decoder.model.0.conv.norm": "decoder.layers.0.norm", + "decoder.model.3.convtr.norm": "decoder.layers.3.norm", + "decoder.model.4.block.1.conv.norm": "decoder.layers.4.block.1.norm", + "decoder.model.4.block.3.conv.norm": "decoder.layers.4.block.3.norm", + "decoder.model.4.shortcut.conv.norm": "decoder.layers.4.shortcut.norm", + "decoder.model.6.convtr.norm": "decoder.layers.6.norm", + "decoder.model.7.block.1.conv.norm": "decoder.layers.7.block.1.norm", + "decoder.model.7.block.3.conv.norm": "decoder.layers.7.block.3.norm", + "decoder.model.7.shortcut.conv.norm": "decoder.layers.7.shortcut.norm", + "decoder.model.9.convtr.norm": "decoder.layers.9.norm", + "decoder.model.10.block.1.conv.norm": "decoder.layers.10.block.1.norm", + "decoder.model.10.block.3.conv.norm": "decoder.layers.10.block.3.norm", + "decoder.model.10.shortcut.conv.norm": "decoder.layers.10.shortcut.norm", + "decoder.model.12.convtr.norm": "decoder.layers.12.norm", + "decoder.model.13.block.1.conv.norm": "decoder.layers.13.block.1.norm", + "decoder.model.13.block.3.conv.norm": "decoder.layers.13.block.3.norm", + "decoder.model.13.shortcut.conv.norm": "decoder.layers.13.shortcut.norm", + "decoder.model.15.conv.norm": "decoder.layers.15.norm", +} +MAPPING_24K = { + **MAPPING_QUANTIZER, + **MAPPING_ENCODER, + **MAPPING_DECODER, +} +MAPPING_48K = { + **MAPPING_QUANTIZER, + **MAPPING_ENCODER, + **MAPPING_ENCODER_48K, + **MAPPING_DECODER, + **MAPPING_DECODER_48K, +} +TOP_LEVEL_KEYS = [] +IGNORE_KEYS = [] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + if hf_shape != value.shape: + raise ValueError( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + elif weight_type == "running_mean": + hf_pointer.running_mean.data = value + elif weight_type == "running_var": + hf_pointer.running_var.data = value + elif weight_type == "num_batches_tracked": + hf_pointer.num_batches_tracked.data = value + elif weight_type == "weight_ih_l0": + hf_pointer.weight_ih_l0.data = value + elif weight_type == "weight_hh_l0": + hf_pointer.weight_hh_l0.data = value + elif weight_type == "bias_ih_l0": + hf_pointer.bias_ih_l0.data = value + elif weight_type == "bias_hh_l0": + hf_pointer.bias_hh_l0.data = value + elif weight_type == "weight_ih_l1": + hf_pointer.weight_ih_l1.data = value + elif weight_type == "weight_hh_l1": + hf_pointer.weight_hh_l1.data = value + elif weight_type == "bias_ih_l1": + hf_pointer.bias_ih_l1.data = value + elif weight_type == "bias_hh_l1": + hf_pointer.bias_hh_l1.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.") + + +def should_ignore(name, ignore_keys): + for key in ignore_keys: + if key.endswith(".*"): + if name.startswith(key[:-1]): + return True + elif ".*." in key: + prefix, suffix = key.split(".*.") + if prefix in name and suffix in name: + return True + elif key in name: + return True + return False + + +def recursively_load_weights(orig_dict, hf_model, model_name): + unused_weights = [] + + if model_name in ["encodec_24khz", "encodec_32khz"]: + MAPPING = MAPPING_24K + elif model_name == "encodec_48khz": + MAPPING = MAPPING_48K + else: + raise ValueError(f"Unsupported model: {model_name}") + + for name, value in orig_dict.items(): + if should_ignore(name, IGNORE_KEYS): + logger.info(f"{name} was ignored") + continue + + is_used = False + for key, mapped_key in MAPPING.items(): + if "*" in key: + prefix, suffix = key.split(".*.") + if prefix in name and suffix in name: + key = suffix + + if key in name: + # HACK otherwise .embed gets initialized with .embed_avg too + if key.endswith("embed") and name.endswith("embed_avg"): + continue + + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "weight_ih_l0" in name: + weight_type = "weight_ih_l0" + elif "weight_hh_l0" in name: + weight_type = "weight_hh_l0" + elif "bias_ih_l0" in name: + weight_type = "bias_ih_l0" + elif "bias_hh_l0" in name: + weight_type = "bias_hh_l0" + elif "weight_ih_l1" in name: + weight_type = "weight_ih_l1" + elif "weight_hh_l1" in name: + weight_type = "weight_hh_l1" + elif "bias_ih_l1" in name: + weight_type = "bias_ih_l1" + elif "bias_hh_l1" in name: + weight_type = "bias_hh_l1" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + weight_type = "weight" + elif "running_mean" in name: + weight_type = "running_mean" + elif "running_var" in name: + weight_type = "running_var" + elif "num_batches_tracked" in name: + weight_type = "num_batches_tracked" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +@torch.no_grad() +def convert_checkpoint( + model_name, + checkpoint_path, + pytorch_dump_folder_path, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + if config_path is not None: + config = EncodecConfig.from_pretrained(config_path) + else: + config = EncodecConfig() + + if model_name == "encodec_24khz": + pass # config is already correct + elif model_name == "encodec_32khz": + config.upsampling_ratios = [8, 5, 4, 4] + config.target_bandwidths = [2.2] + config.num_filters = 64 + config.sampling_rate = 32_000 + config.codebook_size = 2048 + config.use_causal_conv = False + config.normalize = False + config.use_conv_shortcut = False + elif model_name == "encodec_48khz": + config.upsampling_ratios = [8, 5, 4, 2] + config.target_bandwidths = [3.0, 6.0, 12.0, 24.0] + config.sampling_rate = 48_000 + config.audio_channels = 2 + config.use_causal_conv = False + config.norm_type = "time_group_norm" + config.normalize = True + config.chunk_length_s = 1.0 + config.overlap = 0.01 + else: + raise ValueError(f"Unknown model name: {model_name}") + + model = EncodecModel(config) + + feature_extractor = EncodecFeatureExtractor( + feature_size=config.audio_channels, + sampling_rate=config.sampling_rate, + chunk_length_s=config.chunk_length_s, + overlap=config.overlap, + ) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + original_checkpoint = torch.load(checkpoint_path, weights_only=True) + if "best_state" in original_checkpoint: + # we might have a training state saved, in which case discard the yaml results and just retain the weights + original_checkpoint = original_checkpoint["best_state"] + recursively_load_weights(original_checkpoint, model, model_name) + model.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + feature_extractor.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + default="encodec_24khz", + type=str, + help="The model to convert. Should be one of 'encodec_24khz', 'encodec_32khz', 'encodec_48khz'.", + ) + parser.add_argument("--checkpoint_path", required=True, default=None, type=str, help="Path to original checkpoint") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + + args = parser.parse_args() + convert_checkpoint( + args.model, + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.push_to_hub, + ) diff --git a/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..f33191862e4891375059b882f8c50f5953dabc1b --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encodec/feature_extraction_encodec.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for EnCodec.""" + +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class EncodecFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs an EnCodec feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Instantiating a feature extractor with the defaults will yield a similar configuration to that of the + [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + chunk_length_s (`float`, *optional*): + If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. + overlap (`float`, *optional*): + Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following + formulae : `int((1.0 - self.overlap) * self.chunk_length)`. + """ + + model_input_names = ["input_values", "padding_mask"] + + def __init__( + self, + feature_size: int = 1, + sampling_rate: int = 24000, + padding_value: float = 0.0, + chunk_length_s: Optional[float] = None, + overlap: Optional[float] = None, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.chunk_length_s = chunk_length_s + self.overlap = overlap + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_length(self) -> Optional[int]: + if self.chunk_length_s is None: + return None + else: + return int(self.chunk_length_s * self.sampling_rate) + + # This is a property because you might want to change the chunk_length_s on the fly + @property + def chunk_stride(self) -> Optional[int]: + if self.chunk_length_s is None or self.overlap is None: + return None + else: + return max(1, int((1.0 - self.overlap) * self.chunk_length)) + + def __call__( + self, + raw_audio: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Optional[Union[bool, str, PaddingStrategy]] = None, + truncation: Optional[bool] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape + `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio + (`feature_size = 2`). + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, *optional*, defaults to `False`): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + """ + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + if padding and truncation: + raise ValueError("Both padding and truncation were set. Make sure you only set one.") + elif padding is None: + # by default let's pad the inputs + padding = True + + is_batched = bool( + isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) + ) + + if is_batched: + raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] + elif not is_batched and not isinstance(raw_audio, np.ndarray): + raw_audio = np.asarray(raw_audio, dtype=np.float32) + elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): + raw_audio = raw_audio.astype(np.float32) + + # always return batch + if not is_batched: + raw_audio = [np.asarray(raw_audio).T] + + # verify inputs are valid + for idx, example in enumerate(raw_audio): + if example.ndim > 2: + raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") + if self.feature_size == 1 and example.ndim != 1: + raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") + if self.feature_size == 2 and example.shape[-1] != 2: + raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") + + padded_inputs = None + input_values = BatchFeature({"input_values": raw_audio}) + if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: + if truncation: + max_length = min(array.shape[0] for array in raw_audio) + nb_step = int(np.floor(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + elif padding: + max_length = max(array.shape[0] for array in raw_audio) + nb_step = int(np.ceil(max_length / self.chunk_stride)) + max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length + padding = "max_length" + else: + padded_inputs = input_values + + # normal padding on batch + if padded_inputs is None: + padded_inputs = self.pad( + input_values, + max_length=max_length, + truncation=truncation, + padding=padding, + return_attention_mask=padding, + ) + if padding: + padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") + + input_values = [] + for example in padded_inputs.pop("input_values"): + if self.feature_size == 1: + example = example[..., None] + input_values.append(example.T) + + padded_inputs["input_values"] = input_values + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["EncodecFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py b/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..670ac99e03e03a88c07e0478e5190622403eb77a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encodec/modeling_encodec.py @@ -0,0 +1,818 @@ +# coding=utf-8 +# Copyright 2023 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch EnCodec model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_encodec import EncodecConfig + + +logger = logging.get_logger(__name__) + + +# General docstring +_CONFIG_FOR_DOC = "EncodecConfig" + + +@dataclass +class EncodecOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_values (`torch.FlaotTensor` of shape `(batch_size, sequence_length)`, *optional*) + Decoded audio values, obtained using the decoder part of Encodec. + """ + + audio_codes: Optional[torch.LongTensor] = None + audio_values: Optional[torch.FloatTensor] = None + + +@dataclass +class EncodecEncoderOutput(ModelOutput): + """ + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*): + Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding. + """ + + audio_codes: Optional[torch.LongTensor] = None + audio_scales: Optional[torch.FloatTensor] = None + + +@dataclass +class EncodecDecoderOutput(ModelOutput): + """ + Args: + audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*): + Decoded audio values, obtained using the decoder part of Encodec. + """ + + audio_values: Optional[torch.FloatTensor] = None + + +class EncodecConv1d(nn.Module): + """Conv1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1 + ): + super().__init__() + self.causal = config.use_causal_conv + self.pad_mode = config.pad_mode + self.norm_type = config.norm_type + + if self.norm_type not in ["weight_norm", "time_group_norm"]: + raise ValueError( + f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}' + ) + + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + logger.warning( + "EncodecConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation) + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if self.norm_type == "weight_norm": + self.conv = weight_norm(self.conv) + elif self.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels) + + kernel_size = self.conv.kernel_size[0] + stride = torch.tensor(self.conv.stride[0], dtype=torch.int64) + dilation = self.conv.dilation[0] + + # Effective kernel size with dilations. + kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64) + + self.register_buffer("stride", stride, persistent=False) + self.register_buffer("kernel_size", kernel_size, persistent=False) + self.register_buffer("padding_total", kernel_size - stride, persistent=False) + + def _get_extra_padding_for_conv1d( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """See `pad_for_conv1d`.""" + length = hidden_states.shape[-1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = torch.ceil(n_frames).to(torch.int64) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + + return ideal_length - length + + @staticmethod + def _pad1d(hidden_states: torch.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0): + """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happens. + """ + length = hidden_states.shape[-1] + padding_left, padding_right = paddings + if not mode == "reflect": + return nn.functional.pad(hidden_states, paddings, mode, value) + + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + hidden_states = nn.functional.pad(hidden_states, (0, extra_pad)) + padded = nn.functional.pad(hidden_states, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + + def forward(self, hidden_states): + extra_padding = self._get_extra_padding_for_conv1d(hidden_states) + + if self.causal: + # Left padding for causal + hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + hidden_states = self._pad1d( + hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class EncodecConvTranspose1d(nn.Module): + """ConvTranspose1d with asymmetric or causal padding and normalization.""" + + def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1): + super().__init__() + self.causal = config.use_causal_conv + self.trim_right_ratio = config.trim_right_ratio + self.norm_type = config.norm_type + if self.norm_type not in ["weight_norm", "time_group_norm"]: + raise ValueError( + f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`), got {self.norm_type}' + ) + + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if config.norm_type == "weight_norm": + self.conv = weight_norm(self.conv) + elif config.norm_type == "time_group_norm": + self.norm = nn.GroupNorm(1, out_channels) + + if not (self.causal or self.trim_right_ratio == 1.0): + raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions") + + def forward(self, hidden_states): + kernel_size = self.conv.kernel_size[0] + stride = self.conv.stride[0] + padding_total = kernel_size - stride + + hidden_states = self.conv(hidden_states) + + if self.norm_type == "time_group_norm": + hidden_states = self.norm(hidden_states) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + + padding_left = padding_total - padding_right + + # unpad + end = hidden_states.shape[-1] - padding_right + hidden_states = hidden_states[..., padding_left:end] + return hidden_states + + +class EncodecLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout. + """ + + def __init__(self, config, dimension): + super().__init__() + self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers) + + def forward(self, hidden_states): + hidden_states = hidden_states.permute(2, 0, 1) + hidden_states = self.lstm(hidden_states)[0] + hidden_states + hidden_states = hidden_states.permute(1, 2, 0) + return hidden_states + + +class EncodecResnetBlock(nn.Module): + """ + Residual block from SEANet model as used by EnCodec. + """ + + def __init__(self, config: EncodecConfig, dim: int, dilations: List[int]): + super().__init__() + kernel_sizes = (config.residual_kernel_size, 1) + if len(kernel_sizes) != len(dilations): + raise ValueError("Number of kernel sizes should match number of dilations") + + hidden = dim // config.compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [nn.ELU()] + block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)] + self.block = nn.ModuleList(block) + + if config.use_conv_shortcut: + self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() + + def forward(self, hidden_states): + residual = hidden_states + for layer in self.block: + hidden_states = layer(hidden_states) + + return self.shortcut(residual) + hidden_states + + +class EncodecEncoder(nn.Module): + """SEANet encoder as used by EnCodec.""" + + def __init__(self, config: EncodecConfig): + super().__init__() + model = [EncodecConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)] + scaling = 1 + + # Downsample to raw audio scale + for ratio in reversed(config.upsampling_ratios): + current_scale = scaling * config.num_filters + # Add residual layers + for j in range(config.num_residual_layers): + model += [EncodecResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])] + # Add downsampling layers + model += [nn.ELU()] + model += [EncodecConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)] + scaling *= 2 + + model += [EncodecLSTM(config, scaling * config.num_filters)] + model += [nn.ELU()] + model += [EncodecConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)] + + self.layers = nn.ModuleList(model) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EncodecDecoder(nn.Module): + """SEANet decoder as used by EnCodec.""" + + def __init__(self, config: EncodecConfig): + super().__init__() + scaling = int(2 ** len(config.upsampling_ratios)) + model = [EncodecConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)] + + model += [EncodecLSTM(config, scaling * config.num_filters)] + + # Upsample to raw audio scale + for ratio in config.upsampling_ratios: + current_scale = scaling * config.num_filters + # Add upsampling layers + model += [nn.ELU()] + model += [ + EncodecConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio) + ] + # Add residual layers + for j in range(config.num_residual_layers): + model += [EncodecResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))] + scaling //= 2 + + # Add final layers + model += [nn.ELU()] + model += [EncodecConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)] + self.layers = nn.ModuleList(model) + + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class EncodecEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + + def __init__(self, config: EncodecConfig): + super().__init__() + embed = torch.zeros(config.codebook_size, config.codebook_dim) + + self.codebook_size = config.codebook_size + + self.register_buffer("inited", torch.Tensor([True])) + self.register_buffer("cluster_size", torch.zeros(config.codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + def quantize(self, hidden_states): + embed = self.embed.t() + scaled_states = hidden_states.pow(2).sum(1, keepdim=True) + dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def encode(self, hidden_states): + shape = hidden_states.shape + # pre-process + hidden_states = hidden_states.reshape((-1, shape[-1])) + # quantize + embed_ind = self.quantize(hidden_states) + # post-process + embed_ind = embed_ind.view(*shape[:-1]) + return embed_ind + + def decode(self, embed_ind): + quantize = nn.functional.embedding(embed_ind, self.embed) + return quantize + + +class EncodecVectorQuantization(nn.Module): + """ + Vector quantization implementation. Currently supports only euclidean distance. + """ + + def __init__(self, config: EncodecConfig): + super().__init__() + self.codebook = EncodecEuclideanCodebook(config) + + def encode(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1) + embed_in = self.codebook.encode(hidden_states) + return embed_in + + def decode(self, embed_ind): + quantize = self.codebook.decode(embed_ind) + quantize = quantize.permute(0, 2, 1) + return quantize + + +class EncodecResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer.""" + + def __init__(self, config: EncodecConfig): + super().__init__() + self.codebook_size = config.codebook_size + self.frame_rate = config.frame_rate + self.num_quantizers = config.num_quantizers + self.layers = nn.ModuleList([EncodecVectorQuantization(config) for _ in range(config.num_quantizers)]) + + def get_num_quantizers_for_bandwidth(self, bandwidth: Optional[float] = None) -> int: + """Return num_quantizers based on specified target bandwidth.""" + bw_per_q = math.log2(self.codebook_size) * self.frame_rate + num_quantizers = self.num_quantizers + if bandwidth is not None and bandwidth > 0.0: + num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q))) + return num_quantizers + + def encode(self, embeddings: torch.Tensor, bandwidth: Optional[float] = None) -> torch.Tensor: + """ + Encode a given input tensor with the specified frame rate at the given bandwidth. The RVQ encode method sets + the appropriate number of quantizers to use and returns indices for each quantizer. + """ + num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth) + residual = embeddings + all_indices = [] + for layer in self.layers[:num_quantizers]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + quantized_out = torch.tensor(0.0, device=codes.device) + for i, indices in enumerate(codes): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out + + +class EncodecPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EncodecConfig + base_model_prefix = "encodec" + main_input_name = "input_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + nn.init.xavier_uniform_(param) + elif "bias" in name: + nn.init.constant_(param, 0.0) + + +ENCODEC_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`EncodecConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +ENCODEC_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): + Raw audio input converted to Float and padded to the approriate length in order to be encoded using chunks + of length self.chunk_length and a stride of `config.chunk_stride`. + padding_mask (`torch.BoolTensor` of shape `(batch_size, channels, sequence_length)`, *optional*): + Mask to avoid computing scaling factors on padding token indices (can we avoid computing conv on these+). + Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + + + `padding_mask` should always be passed, unless the input was truncated or not padded. This is because in + order to process tensors effectively, the input audio should be padded so that `input_length % stride = + step` with `step = chunk_length-stride`. This ensures that all chunks are of the same shape + + + + bandwidth (`float`, *optional*): + The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible + bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as + `bandwidth == 6.0` + audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*): + Scaling factor for each `audio_codes` input. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The EnCodec neural audio codec model.", + ENCODEC_START_DOCSTRING, +) +class EncodecModel(EncodecPreTrainedModel): + def __init__(self, config: EncodecConfig): + super().__init__(config) + self.config = config + + self.encoder = EncodecEncoder(config) + self.decoder = EncodecDecoder(config) + + self.quantizer = EncodecResidualVectorQuantizer(config) + + self.bits_per_codebook = int(math.log2(self.config.codebook_size)) + if 2**self.bits_per_codebook != self.config.codebook_size: + raise ValueError("The codebook_size must be a power of 2.") + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _encode_frame( + self, input_values: torch.Tensor, bandwidth: float, padding_mask: int + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first + normalized. The padding mask is required to compute the correct scale. + """ + length = input_values.shape[-1] + duration = length / self.config.sampling_rate + + if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s: + raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}") + + scale = None + if self.config.normalize: + # if the padding is non zero + input_values = input_values * padding_mask.unsqueeze(1) + mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1] + scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8 + input_values = input_values / scale + + embeddings = self.encoder(input_values) + codes = self.quantizer.encode(embeddings, bandwidth) + codes = codes.transpose(0, 1) + return codes, scale + + def encode( + self, + input_values: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + bandwidth: Optional[float] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], EncodecEncoderOutput]: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Float values of the input audio waveform. + padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Padding mask used to pad the `input_values`. + bandwidth (`float`, *optional*): + The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible + bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented + as bandwidth == 6.0 + + Returns: + A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling + factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with + `codebook` of shape `[batch_size, num_codebooks, frames]`. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if bandwidth is None: + bandwidth = self.config.target_bandwidths[0] + if bandwidth not in self.config.target_bandwidths: + raise ValueError( + f"This model doesn't support the bandwidth {bandwidth}. Select one of {self.config.target_bandwidths}." + ) + + _, channels, input_length = input_values.shape + + if channels < 1 or channels > 2: + raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}") + + chunk_length = self.config.chunk_length + if chunk_length is None: + chunk_length = input_length + stride = input_length + else: + stride = self.config.chunk_stride + + if padding_mask is None: + padding_mask = torch.ones_like(input_values).bool() + + encoded_frames = [] + scales = [] + + step = chunk_length - stride + if (input_length % stride) - step != 0: + raise ValueError( + "The input length is not properly padded for batched chunked decoding. Make sure to pad the input correctly." + ) + + for offset in range(0, input_length - step, stride): + mask = padding_mask[..., offset : offset + chunk_length].bool() + frame = input_values[:, :, offset : offset + chunk_length] + encoded_frame, scale = self._encode_frame(frame, bandwidth, mask) + encoded_frames.append(encoded_frame) + scales.append(scale) + + encoded_frames = torch.stack(encoded_frames) + + if not return_dict: + return (encoded_frames, scales) + + return EncodecEncoderOutput(encoded_frames, scales) + + @staticmethod + def _linear_overlap_add(frames: List[torch.Tensor], stride: int): + # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario + # e.g., more than 2 frames per position. + # The core idea is to use a weight function that is a triangle, + # with a maximum value at the middle of the chunk. + # We use this weighting when summing the frames, and divide by the sum of weights + # for each positions at the end. Thus: + # - if a frame is the only one to cover a position, the weighting is a no-op. + # - if 2 frames cover a position: + # ... ... + # / \/ \ + # / /\ \ + # S T , i.e. S offset of second frame starts, T end of first frame. + # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset. + # After the final normalization, the weight of the second frame at position `t` is + # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. + # + # - if more than 2 frames overlap at a given point, we hope that by induction + # something sensible happens. + if len(frames) == 0: + raise ValueError("`frames` cannot be an empty list.") + + device = frames[0].device + dtype = frames[0].dtype + shape = frames[0].shape[:-1] + total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] + + frame_length = frames[0].shape[-1] + time_vec = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1] + weight = 0.5 - (time_vec - 0.5).abs() + + sum_weight = torch.zeros(total_size, device=device, dtype=dtype) + out = torch.zeros(*shape, total_size, device=device, dtype=dtype) + offset: int = 0 + + for frame in frames: + frame_length = frame.shape[-1] + out[..., offset : offset + frame_length] += weight[:frame_length] * frame + sum_weight[offset : offset + frame_length] += weight[:frame_length] + offset += stride + + if sum_weight.min() == 0: + raise ValueError(f"`sum_weight` minimum element must be bigger than zero: {sum_weight}`") + + return out / sum_weight + + def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor: + codes = codes.transpose(0, 1) + embeddings = self.quantizer.decode(codes) + outputs = self.decoder(embeddings) + if scale is not None: + outputs = outputs * scale.view(-1, 1, 1) + return outputs + + def decode( + self, + audio_codes: torch.Tensor, + audio_scales: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be + trimmed. + + Args: + audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*): + Scaling factor for each `audio_codes` input. + padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + Padding mask used to pad the `input_values`. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + chunk_length = self.config.chunk_length + if chunk_length is None: + if len(audio_codes) != 1: + raise ValueError(f"Expected one frame, got {len(audio_codes)}") + audio_values = self._decode_frame(audio_codes[0], audio_scales[0]) + else: + decoded_frames = [] + + for frame, scale in zip(audio_codes, audio_scales): + frames = self._decode_frame(frame, scale) + decoded_frames.append(frames) + + audio_values = self._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1) + + # truncate based on padding mask + if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]: + audio_values = audio_values[..., : padding_mask.shape[-1]] + + if not return_dict: + return (audio_values,) + return EncodecDecoderOutput(audio_values) + + @add_start_docstrings_to_model_forward(ENCODEC_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=EncodecOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_values: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + bandwidth: Optional[float] = None, + audio_codes: Optional[torch.Tensor] = None, + audio_scales: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoProcessor, EncodecModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model_id = "facebook/encodec_24khz" + >>> model = EncodecModel.from_pretrained(model_id) + >>> processor = AutoProcessor.from_pretrained(model_id) + + >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> audio_codes = outputs.audio_codes + >>> audio_values = outputs.audio_values + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if padding_mask is None: + padding_mask = torch.ones_like(input_values).bool() + + if audio_codes is not None and audio_scales is None: + raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`") + + if audio_scales is not None and audio_codes is None: + raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`") + + if audio_scales is None and audio_codes is None: + audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, False) + + audio_values = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0] + if not return_dict: + return (audio_codes, audio_values) + + return EncodecOutput(audio_codes=audio_codes, audio_values=audio_values) + + +__all__ = ["EncodecModel", "EncodecPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c786feb9213fdd31640c0fdeaead5164026ad37a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_encoder_decoder import * + from .modeling_encoder_decoder import * + from .modeling_flax_encoder_decoder import * + from .modeling_tf_encoder_decoder import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a5eff83e55824786d70c821ecfa210e07c27da2e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/configuration_encoder_decoder.py @@ -0,0 +1,111 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import AutoConfig + + +logger = logging.get_logger(__name__) + + +class EncoderDecoderConfig(PretrainedConfig): + r""" + [`EncoderDecoderConfig`] is the configuration class to store the configuration of a [`EncoderDecoderModel`]. It is + used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder + configs. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + kwargs (*optional*): + Dictionary of keyword arguments. Notably: + + - **encoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the encoder config. + - **decoder** ([`PretrainedConfig`], *optional*) -- An instance of a configuration object that defines + the decoder config. + + Examples: + + ```python + >>> from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel + + >>> # Initializing a BERT google-bert/bert-base-uncased style configuration + >>> config_encoder = BertConfig() + >>> config_decoder = BertConfig() + + >>> config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder) + + >>> # Initializing a Bert2Bert model (with random weights) from the google-bert/bert-base-uncased style configurations + >>> model = EncoderDecoderModel(config=config) + + >>> # Accessing the model configuration + >>> config_encoder = model.config.encoder + >>> config_decoder = model.config.decoder + >>> # set decoder config to causal lm + >>> config_decoder.is_decoder = True + >>> config_decoder.add_cross_attention = True + + >>> # Saving the model, including its configuration + >>> model.save_pretrained("my-model") + + >>> # loading model and config from pretrained folder + >>> encoder_decoder_config = EncoderDecoderConfig.from_pretrained("my-model") + >>> model = EncoderDecoderModel.from_pretrained("my-model", config=encoder_decoder_config) + ```""" + + model_type = "encoder-decoder" + sub_configs = {"encoder": AutoConfig, "decoder": AutoConfig} + has_no_defaults_at_init = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "encoder" not in kwargs or "decoder" not in kwargs: + raise ValueError( + f"A configuraton of type {self.model_type} cannot be instantiated because " + f"both `encoder` and `decoder` sub-configurations were not passed, only {kwargs}" + ) + encoder_config = kwargs.pop("encoder") + encoder_model_type = encoder_config.pop("model_type") + decoder_config = kwargs.pop("decoder") + decoder_model_type = decoder_config.pop("model_type") + + self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config) + self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config) + self.is_encoder_decoder = True + + @classmethod + def from_encoder_decoder_configs( + cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs + ) -> PretrainedConfig: + r""" + Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and + decoder model configuration. + + Returns: + [`EncoderDecoderConfig`]: An instance of a configuration object + """ + logger.info("Set `config.is_decoder=True` and `config.add_cross_attention=True` for decoder_config") + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict(), **kwargs) + + +__all__ = ["EncoderDecoderConfig"] diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..415fd058e45d8fe6e44ca30ddfe178ffc717cf19 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -0,0 +1,689 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Encoder-Decoder architectures""" + +import gc +import inspect +import os +import tempfile +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...configuration_utils import PretrainedConfig +from ...generation import GenerationMixin +from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_auto import AutoModel, AutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +DEPRECATION_WARNING = ( + "Version v4.12.0 introduces a better way to train encoder-decoder models by computing the loss inside the" + " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" + " fine-tuning a model trained with versions anterior to 4.12.0. The decoder_input_ids are now created based on the" + " labels, no need to pass them yourself anymore." +) + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] + function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream + generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the + right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): + This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor + of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the + decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. + kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + + - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. + - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function. +""" + + +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class EncoderDecoderModel(PreTrainedModel, GenerationMixin): + r""" + [`EncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one + of the base model classes of the library as encoder and another one as decoder when created with the + :meth*~transformers.AutoModel.from_pretrained* class method for the encoder and + :meth*~transformers.AutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"Config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if encoder is None: + from ..auto.modeling_auto import AutoModel + + encoder = AutoModel.from_config(config.encoder) + + if decoder is None: + from ..auto.modeling_auto import AutoModelForCausalLM + + decoder = AutoModelForCausalLM.from_config(config.decoder) + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + # tie encoder, decoder weights if config set accordingly + self.tie_weights() + + def tie_weights(self): + self.encoder.tie_weights() + self.decoder.tie_weights() + # tie encoder & decoder if needed + if self.config.tie_encoder_decoder: + # tie encoder and decoder base model + decoder_base_model_prefix = self.decoder.base_model_prefix + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + "encoder", + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights + + def _init_weights(self, module): + if module in self.encoder.modules(): + self.encoder._init_weights(module) + elif module in self.decoder.modules(): + self.decoder._init_weights(module) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Example: + + ```python + >>> from transformers import EncoderDecoderModel + + >>> model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") + ```""" + + from_tf = kwargs.pop("from_tf", False) + if from_tf: + from transformers import TFEncoderDecoderModel + + # a workaround to load from tensorflow checkpoint + # Using `_tf_model` won't work, because the weight names in the encoder/decoder of `_tf_model` get + # extended before saving those components. For example, The name of `_tf_model.encoder.vit` is + # `[top model name]/encoder/vit`, but the name of `tf_model.encoder.vit` is `[top model name]/vit`. The + # [top model name] is handled (stripped) by the conversion method, and the former case gets extra `encoder`, + # which should not occur when we want to save the components alone. + # There was a (very) ugly potential fix, which wasn't integrated to `transformers`: see + # https://github.com/huggingface/transformers/pull/13222/commits/dbb3c9de76eee235791d2064094654637c99f36d#r697304245 + # (the change in `src/transformers/modeling_tf_utils.py`) + _tf_model = TFEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + config = _tf_model.config + + # Using `tf_model` instead + encoder = _tf_model.encoder.__class__(_tf_model.config.encoder) + decoder = _tf_model.decoder.__class__(_tf_model.config.decoder) + # Make sure models are built + encoder(encoder.dummy_inputs) + decoder(decoder.dummy_inputs) + + # Get the variable correspondence between `_tf_model` and `encoder` and `decoder` + encoder_variables = {} + for v in encoder.trainable_variables + encoder.non_trainable_variables: + encoder_variables["/".join(v.name.split("/")[1:])] = v + decoder_variables = {} + for v in decoder.trainable_variables + decoder.non_trainable_variables: + decoder_variables["/".join(v.name.split("/")[1:])] = v + + _encoder_variables = {} + for v in _tf_model.encoder.trainable_variables + _tf_model.encoder.non_trainable_variables: + _encoder_variables["/".join(v.name.split("/")[2:])] = v + _decoder_variables = {} + for v in _tf_model.decoder.trainable_variables + _tf_model.decoder.non_trainable_variables: + _decoder_variables["/".join(v.name.split("/")[2:])] = v + + # assign weight values to `encoder` and `decoder` from `_tf_model` + for name, v in encoder_variables.items(): + v.assign(_encoder_variables[name]) + for name, v in decoder_variables.items(): + v.assign(_decoder_variables[name]) + + tf_model = TFEncoderDecoderModel(encoder=encoder, decoder=decoder) + + # Deal with `enc_to_dec_proj` + if hasattr(_tf_model, "enc_to_dec_proj"): + tf_model(tf_model.dummy_inputs) + tf_model.enc_to_dec_proj.kernel.assign(_tf_model.enc_to_dec_proj.kernel) + tf_model.enc_to_dec_proj.bias.assign(_tf_model.enc_to_dec_proj.bias) + + with tempfile.TemporaryDirectory() as tmpdirname: + encoder_dir = os.path.join(tmpdirname, "encoder") + decoder_dir = os.path.join(tmpdirname, "decoder") + tf_model.encoder.save_pretrained(encoder_dir) + tf_model.decoder.save_pretrained(decoder_dir) + + if hasattr(tf_model, "enc_to_dec_proj"): + enc_to_dec_proj_weight = torch.transpose( + torch.from_numpy(tf_model.enc_to_dec_proj.kernel.numpy()), 1, 0 + ) + enc_to_dec_proj_bias = torch.from_numpy(tf_model.enc_to_dec_proj.bias.numpy()) + + del _tf_model + del tf_model + gc.collect() + + model = EncoderDecoderModel.from_encoder_decoder_pretrained( + encoder_dir, decoder_dir, encoder_from_tf=True, decoder_from_tf=True + ) + # This is only for copying some specific attributes of this particular model. + model.config = config + + if hasattr(model, "enc_to_dec_proj"): + model.enc_to_dec_proj.weight.data = enc_to_dec_proj_weight.contiguous() + model.enc_to_dec_proj.bias.data = enc_to_dec_proj_bias.contiguous() + + return model + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[str] = None, + decoder_pretrained_model_name_or_path: Optional[str] = None, + *model_args, + **kwargs, + ) -> PreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you need to first set it back in training mode with `model.train()`. + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args (remaining positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import EncoderDecoderModel + + >>> # initialize a bert2bert from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "google-bert/bert-base-uncased") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2bert") + >>> # load fine-tuned model + >>> model = EncoderDecoderModel.from_pretrained("./bert2bert") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) + + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, Seq2SeqLMOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import EncoderDecoderModel, BertTokenizer + >>> import torch + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained( + ... "google-bert/bert-base-uncased", "google-bert/bert-base-uncased" + ... ) # initialize Bert2Bert from pre-trained checkpoints + + >>> # training + >>> model.config.decoder_start_token_id = tokenizer.cls_token_id + >>> model.config.pad_token_id = tokenizer.pad_token_id + >>> model.config.vocab_size = model.config.decoder.vocab_size + + >>> input_ids = tokenizer("This is a really long text", return_tensors="pt").input_ids + >>> labels = tokenizer("This is the corresponding summary", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2bert") + >>> model = EncoderDecoderModel.from_pretrained("bert2bert") + + >>> # generation + >>> generated = model.generate(input_ids) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + if "num_items_in_batch" in kwargs_encoder: + kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + if decoder_attention_mask is None: + decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + **kwargs_decoder, + ) + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + logits = decoder_outputs.logits if return_dict else decoder_outputs[0] + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1)) + + if not return_dict: + if loss is not None: + return (loss,) + decoder_outputs + encoder_outputs + else: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" + " model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) + + +__all__ = ["EncoderDecoderModel"] diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc589484cda1083f42ce96e67b12771f84b6af9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -0,0 +1,901 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support Flax Encoder-Decoder architectures""" + +import os +from typing import Optional, Tuple, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput +from ...modeling_flax_utils import FlaxPreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_flax_auto import FlaxAutoModel, FlaxAutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] + function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream + generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Parameters: + config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.encoder.max_position_embeddings - 1]`. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. +""" + +ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.encoder.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. +""" + +ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be + created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` + and prepending them with the `decoder_start_token_id`. + encoder_outputs (`tuple(tuple(jnp.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.decoder.max_position_embeddings - 1]`. + past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a + plain tuple. +""" + + +class FlaxEncoderDecoderModule(nn.Module): + config: EncoderDecoderConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + encoder_config = self.config.encoder + decoder_config = self.config.decoder + + # Copied from `modeling_hybrid_clip.py` with modifications. + from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING + + encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class + decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class + + self.encoder = encoder_module(encoder_config, dtype=self.dtype) + self.decoder = decoder_module(decoder_config, dtype=self.dtype) + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = nn.Dense( + self.decoder.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range), + dtype=self.dtype, + ) + else: + self.enc_to_dec_proj = None + + def _get_encoder_module(self): + return self.encoder + + def _get_projection_module(self): + return self.enc_to_dec_proj + + def _get_decoder_module(self): + return self.decoder + + def __call__( + self, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if self.enc_to_dec_proj is not None: + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class FlaxEncoderDecoderModel(FlaxPreTrainedModel): + r""" + [`FlaxEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with + the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one as + decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the + encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. + """ + + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + module_class = FlaxEncoderDecoderModule + + def __init__( + self, + config: EncoderDecoderConfig, + input_shape: Optional[Tuple] = None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + if input_shape is None: + input_shape = ((1, 1), (1, 1)) + + if not _do_init: + raise ValueError( + "`FlaxEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`." + ) + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + encoder_input_shape, decoder_input_shape = input_shape + + # init input tensors + input_ids = jnp.zeros(encoder_input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_ids) + decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape + if not decoder_batch_size == batch_size: + raise ValueError( + f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder" + f" and {decoder_batch_size} for decoder." + ) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) + ) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_ids, + attention_mask, + decoder_input_ids, + decoder_attention_mask, + position_ids, + decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC) + def encode( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> text = "My friends are cool but they eat too many carbs." + >>> input_ids = tokenizer.encode(text, return_tensors="np") + >>> encoder_outputs = model.encode(input_ids) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_ids, attention_mask, position_ids, **kwargs) + + outputs = self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + if return_dict: + outputs = FlaxBaseModelOutput( + last_hidden_state=outputs.last_hidden_state, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return outputs + + @add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer + >>> import jax.numpy as jnp + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> text = "My friends are cool but they eat too many carbs." + >>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors="np") + >>> encoder_outputs = model.encode(input_ids) + + >>> decoder_start_token_id = model.config.decoder.bos_token_id + >>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> logits = outputs.logits + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + if encoder_attention_mask is None: + batch_size, sequence_length = encoder_hidden_states.shape[:2] + encoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxBartAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward( + module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs + ): + projection_module = module._get_projection_module() + decoder_module = module._get_decoder_module() + + # optionally project encoder_hidden_states + if projection_module is not None: + encoder_hidden_states = projection_module(encoder_hidden_states) + + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + encoder_hidden_states=encoder_hidden_states, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_input_ids: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Examples: + + ```python + >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer + + >>> # load a fine-tuned bert2gpt2 model + >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16") + >>> # load input & output tokenizer + >>> tokenizer_input = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + >>> tokenizer_output = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + + >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members + >>> singing a racist chant. SAE's national chapter suspended the students, + >>> but University of Oklahoma President David Boren took it a step further, + >>> saying the university's affiliation with the fraternity is permanently done.''' + + >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids + + >>> # use GPT2's eos_token as the pad as well as eos token + >>> model.config.eos_token_id = model.config.decoder.eos_token_id + >>> model.config.pad_token_id = model.config.eos_token_id + + >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences + + >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0] + >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members" + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare encoder inputs + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + # prepare decoder inputs + if decoder_input_ids is None: + raise ValueError( + "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must" + " be specified as an input argument." + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_ids=jnp.array(input_ids, dtype="i4"), + attention_mask=jnp.array(attention_mask, dtype="i4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, + encoder_outputs=None, + **kwargs, + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) + ) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": decoder_position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, + *model_args, + **kwargs, + ) -> FlaxPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + Params: + encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + model_args (remaining positional arguments, *optional*): + All remaning positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import FlaxEncoderDecoderModel + + >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2gpt2") + >>> # load fine-tuned model + >>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config, kwargs_encoder = AutoConfig.from_pretrained( + encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True + ) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + encoder = FlaxAutoModel.from_pretrained( + encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder + ) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config, kwargs_decoder = AutoConfig.from_pretrained( + decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True + ) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # instantiate config with corresponding kwargs + dtype = kwargs.pop("dtype", jnp.float32) + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + + # init model + model = cls(config, dtype=dtype) + model.params["encoder"] = encoder.params + model.params["decoder"] = decoder.params + + return model + + +__all__ = ["FlaxEncoderDecoderModel"] diff --git a/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a5abafc361b6b4e71e365d7b7f3caa3368f82d92 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -0,0 +1,665 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classes to support TF Encoder-Decoder architectures""" + +from __future__ import annotations + +import inspect +import re +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...configuration_utils import PretrainedConfig +from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + unpack_inputs, +) +from ...tf_utils import shape_list +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..auto.configuration_auto import AutoConfig +from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM +from .configuration_encoder_decoder import EncoderDecoderConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "EncoderDecoderConfig" + +DEPRECATION_WARNING = ( + "Version v4.17.0 introduces a better way to train encoder-decoder models by computing the loss inside the" + " encoder-decoder framework rather than in the decoder itself. You may observe training discrepancies if" + " fine-tuning a model trained with versions anterior to 4.17.0. The decoder_input_ids are now created based on the" + " labels, no need to pass them yourself anymore." +) + +ENCODER_DECODER_START_DOCSTRING = r""" + This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the + encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via + [`~TFAutoModel.from_pretrained`] function and the decoder is loaded via [`~TFAutoModelForCausalLM.from_pretrained`] + function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream + generative task, like summarization. + + The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation + tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation + Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi + Zhou, Wei Li, Peter J. Liu. + + After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models + (see the examples for more information). + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`EncoderDecoderConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ENCODER_DECODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + Provide for sequence to sequence training to the decoder. Indices can be obtained using + [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for + details. + decoder_attention_mask (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + encoder_outputs (`tuple(tuple(tf.Tensor)`, *optional*): + This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` (`tf.Tensor` of shape `({0}, hidden_size)`) is a tensor of hidden-states at the output + of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(tf.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `({0})`. + inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. This is useful if you want more control over how to convert `decoder_input_ids` indices + into associated vectors than the model's internal embedding lookup matrix. + labels (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, + ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). + kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: + + - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. + - With a *decoder_* prefix which will be input as `**decoder_kwargs`` for the decoder forward function. +""" + + +def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): + if pad_token_id is None: + raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") + pad_token_id = tf.cast(pad_token_id, input_ids.dtype) + + if decoder_start_token_id is None: + raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") + decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype) + + start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id) + shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1) + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids = tf.where( + shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids + ) + + # "Verify that `labels` has only positive values and -100" + assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype)) + + # Make sure the assertion op is called by wrapping the result in an identity no-op + with tf.control_dependencies([assert_gte0]): + shifted_input_ids = tf.identity(shifted_input_ids) + + return shifted_input_ids + + +@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING) +class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): + r""" + [`TFEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture with one + of the base model classes of the library as encoder and another one as decoder when created with the + [`~TFAutoModel.from_pretrained`] class method for the encoder and [`~TFAutoModelForCausalLM.from_pretrained`] class + method for the decoder. + """ + + config_class = EncoderDecoderConfig + base_model_prefix = "encoder_decoder" + load_weight_prefix = "tf_encoder_decoder_model" + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[TFPreTrainedModel] = None, + decoder: Optional[TFPreTrainedModel] = None, + ): + if config is None and (encoder is None or decoder is None): + raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") + if config is None: + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config) + else: + if not isinstance(config, self.config_class): + raise ValueError(f"config: {config} has to be of type {self.config_class}") + + if config.decoder.cross_attention_hidden_size is not None: + if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: + raise ValueError( + "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal" + f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" + f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" + " `config.encoder.hidden_size`." + ) + + # initialize with config + super().__init__(config) + + if encoder is None: + encoder = TFAutoModel.from_config(config.encoder, name="encoder") + + if decoder is None: + decoder = TFAutoModelForCausalLM.from_config(config.decoder, name="decoder") + + self.encoder = encoder + self.decoder = decoder + + if self.encoder.config.to_dict() != self.config.encoder.to_dict(): + logger.warning( + f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:" + f" {self.config.encoder}" + ) + if self.decoder.config.to_dict() != self.config.decoder.to_dict(): + logger.warning( + f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:" + f" {self.config.decoder}" + ) + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + + # encoder outputs might need to be projected to different dimension for decoder + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + self.enc_to_dec_proj = keras.layers.Dense( + units=self.decoder.config.hidden_size, + kernel_initializer=get_initializer(config.encoder.initializer_range), + name="enc_to_dec_proj", + ) + + if self.encoder.get_output_embeddings() is not None: + raise ValueError( + f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" + ) + + decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys()) + if "encoder_hidden_states" not in decoder_signature: + raise ValueError( + "The selected decoder is not prepared for the encoder hidden states to be passed. Please see the " + "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350" + ) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_input_embeddings(self): + return self.encoder.get_input_embeddings() + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def tf_to_pt_weight_rename(self, tf_weight): + # Matt: The TF and PT weights don't align because our TF base classes have an extra layer compared to PT models + # (the main model stem is in the MainLayer class). If we remove that layer, then weight names sync up as normal. + # However, the name of that extra layer is the name of the MainLayer in the base model. We make the assumption + # here that the config model_type is the same as the name of the MainLayer. I don't know of anywhere that's + # not the case, and I wasn't sure how else to go from the config to the correct MainLayer name! + + # This override is only needed in the case where we're crossloading weights from PT. However, since weights are + # often safetensors now, we don't know if we're going to be crossloading until we sniff the weights file. + # Therefore, we specify tf_to_pt_weight_rename anyway, and let the super method figure out if it needs it + # or not. + encoder_model_type = self.config.encoder.model_type + if "encoder" in tf_weight and "decoder" not in tf_weight: + return (re.sub(rf"encoder\.{encoder_model_type}\.", "encoder.", tf_weight),) + else: + return (tf_weight,) + + @classmethod + def from_encoder_decoder_pretrained( + cls, + encoder_pretrained_model_name_or_path: Optional[str] = None, + decoder_pretrained_model_name_or_path: Optional[str] = None, + *model_args, + **kwargs, + ) -> TFPreTrainedModel: + r""" + Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model + checkpoints. + + + Params: + encoder_pretrained_model_name_or_path (`str`, *optional*): + Information necessary to initiate the encoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pytorch index checkpoint file* (e.g, `./pt_model/`). In this case, + `encoder_from_pt` should be set to `True`. + + decoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`): + Information necessary to initiate the decoder. Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + - A path to a *directory* containing model weights saved using + [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *pytorch checkpoint file* (e.g, `./pt_model/`). In this case, + `decoder_from_pt` should be set to `True`. + + model_args (remaining positional arguments, *optional*): + All remaning positional arguments will be passed to the underlying model's `__init__` method. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). + + - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter. + - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a `config` is provided or automatically loaded. + + Example: + + ```python + >>> from transformers import TFEncoderDecoderModel + + >>> # initialize a bert2gpt2 from two pretrained BERT models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-uncased", "openai-community/gpt2") + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert2gpt2") + >>> # load fine-tuned model + >>> model = TFEncoderDecoderModel.from_pretrained("./bert2gpt2") + ```""" + + kwargs_encoder = { + argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_") + } + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # remove encoder, decoder kwargs from kwargs + for key in kwargs_encoder.keys(): + del kwargs["encoder_" + key] + for key in kwargs_decoder.keys(): + del kwargs["decoder_" + key] + + # Load and initialize the encoder and decoder + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("model", None) + if encoder is None: + if encoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_encoder: + encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path) + if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True: + logger.info( + f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model " + "from a decoder model. Cross-attention and casual mask are disabled." + ) + encoder_config.is_decoder = False + encoder_config.add_cross_attention = False + + kwargs_encoder["config"] = encoder_config + + kwargs_encoder["name"] = "encoder" + kwargs_encoder["load_weight_prefix"] = cls.load_weight_prefix + encoder = TFAutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder) + + decoder = kwargs_decoder.pop("model", None) + if decoder is None: + if decoder_pretrained_model_name_or_path is None: + raise ValueError( + "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has " + "to be defined." + ) + + if "config" not in kwargs_decoder: + decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path) + if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False: + logger.info( + f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention" + f" layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if" + f" {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers." + ) + decoder_config.is_decoder = True + decoder_config.add_cross_attention = True + + kwargs_decoder["config"] = decoder_config + + if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False: + logger.warning( + f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. " + f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, " + "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` " + "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a " + "`decoder_config` to `.from_encoder_decoder_pretrained(...)`" + ) + + kwargs_decoder["name"] = "decoder" + kwargs_decoder["load_weight_prefix"] = cls.load_weight_prefix + decoder = TFAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder) + + # Make sure these 2 `keras.Model` have fixed names so `from_pretrained` could load model weights correctly. + if encoder.name != "encoder": + raise ValueError("encoder model must be created with the name `encoder`.") + if decoder.name != "decoder": + raise ValueError("decoder model must be created with the name `decoder`.") + + # instantiate config with corresponding kwargs + config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) + return cls(encoder=encoder, decoder=decoder, config=config) + + @unpack_inputs + @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + decoder_input_ids: np.ndarray | tf.Tensor | None = None, + decoder_attention_mask: np.ndarray | tf.Tensor | None = None, + encoder_outputs: np.ndarray | tf.Tensor | None = None, + past_key_values: Tuple[Tuple[tf.Tensor]] | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs, + ) -> Union[TFSeq2SeqLMOutput, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import TFEncoderDecoderModel, BertTokenizer + + >>> # initialize a bert2gpt2 from a pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized + >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("google-bert/bert-base-cased", "openai-community/gpt2") + + >>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased") + + >>> # forward + >>> input_ids = tokenizer.encode( + ... "Hello, my dog is cute", add_special_tokens=True, return_tensors="tf" + ... ) # Batch size 1 + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids) + + >>> # training + >>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids) + >>> loss, logits = outputs.loss, outputs.logits + + >>> # save and load from pretrained + >>> model.save_pretrained("bert2gpt2") + >>> model = TFEncoderDecoderModel.from_pretrained("bert2gpt2") + + >>> # generation + >>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # Let the user be responsible for the expected format. + if encoder_outputs is not None: + if return_dict and not isinstance(encoder_outputs, ModelOutput): + raise ValueError( + "If `return_dict=True` and `encoder_outputs` is provided, it should be an instance of " + f"`ModelOutput`. Got an instance {type(encoder_outputs)} for `encoder_outputs`." + ) + + if encoder_outputs is None: + encoder_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "training": training, + } + + # Add arguments to encoder from `kwargs_encoder` + encoder_inputs.update(kwargs_encoder) + + # Handle the case where the inputs are passed as a single dict which contains `labels`. + # The `labels` shouldn't be passed to `self.encoder` below, because it is a based model without this + # parameter (otherwise, an error occurs when `input_processing` is called inside `self.encoder.call()`). + if "labels" in encoder_inputs: + labels = encoder_inputs.pop("labels") + + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. + if "decoder_input_ids" in encoder_inputs: + decoder_input_ids = encoder_inputs.pop("decoder_input_ids") + # handle the init case where `dummy_inputs` returns a dict containing `decoder_input_ids`. + if "decoder_attention_mask" in encoder_inputs: + decoder_attention_mask = encoder_inputs.pop("decoder_attention_mask") + + encoder_outputs = self.encoder(**encoder_inputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None): + decoder_input_ids = shift_tokens_right( + labels, self.config.pad_token_id, self.config.decoder_start_token_id + ) + + decoder_inputs = { + "input_ids": decoder_input_ids, + "attention_mask": decoder_attention_mask, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": attention_mask, + "inputs_embeds": decoder_inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "use_cache": use_cache, + "past_key_values": past_key_values, + "return_dict": return_dict, + "training": training, + } + + # Add arguments to decoder from `kwargs_decoder` + decoder_inputs.update(kwargs_decoder) + + decoder_outputs = self.decoder(**decoder_inputs) + + logits = decoder_outputs[0] + + # Compute loss independent from decoder (as some shift the logits inside them) + loss = None + if labels is not None: + warnings.warn(DEPRECATION_WARNING, FutureWarning) + loss = self.hf_compute_loss(labels, logits) + + if not return_dict: + past_key_values = None + if use_cache: + past_key_values = decoder_outputs[1] + # The starting index of the remaining elements in `decoder_outputs` + start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) + + if not isinstance(encoder_outputs, tuple): + encoder_outputs = encoder_outputs.to_tuple() + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs + output = tuple([x for x in output if x is not None]) + return output + + return TFSeq2SeqLMOutput( + loss=loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + past_key_values = decoder_inputs.get("past_key_values") + if past_key_values is None: + past_key_values = decoder_inputs.get("past") # e.g. on TF GPT2 + input_dict = { + "input_ids": None, # needs to be passed to make Keras.layer.__call__ happy + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete + "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]), + "past_key_values": past_key_values, + "use_cache": use_cache, + } + return input_dict + + def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the TFEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or" + " model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past, beam_idx) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "enc_to_dec_proj", None) is not None: + with tf.name_scope(self.enc_to_dec_proj.name): + self.enc_to_dec_proj.build([None, None, self.encoder.config.hidden_size]) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +__all__ = ["TFEncoderDecoderModel"] diff --git a/docs/transformers/build/lib/transformers/models/ernie/__init__.py b/docs/transformers/build/lib/transformers/models/ernie/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb8983063ddb0117e8b0d7cd6603aa6ac3056b6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/ernie/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ernie import * + from .modeling_ernie import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py b/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..655e40e163b59dac4f2cab5fe96265b2173478c1 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/ernie/configuration_ernie.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2022 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ERNIE model configuration""" + +from collections import OrderedDict +from typing import Mapping + +from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ErnieConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ErnieModel`] or a [`TFErnieModel`]. It is used to + instantiate a ERNIE model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ERNIE + [nghuyong/ernie-3.0-base-zh](https://huggingface.co/nghuyong/ernie-3.0-base-zh) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the ERNIE model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`ErnieModel`] or [`TFErnieModel`]. + task_type_vocab_size (`int`, *optional*, defaults to 3): + The vocabulary size of the `task_type_ids` for ERNIE2.0/ERNIE3.0 model + use_task_id (`bool`, *optional*, defaults to `False`): + Whether or not the model support `task_type_ids` + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import ErnieConfig, ErnieModel + + >>> # Initializing a ERNIE nghuyong/ernie-3.0-base-zh style configuration + >>> configuration = ErnieConfig() + + >>> # Initializing a model (with random weights) from the nghuyong/ernie-3.0-base-zh style configuration + >>> model = ErnieModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "ernie" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + task_type_vocab_size=3, + use_task_id=False, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.task_type_vocab_size = task_type_vocab_size + self.use_task_id = use_task_id + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout + + +class ErnieOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ("task_type_ids", dynamic_axis), + ] + ) + + +__all__ = ["ErnieConfig", "ErnieOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py b/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py new file mode 100644 index 0000000000000000000000000000000000000000..221dd57485880f6210930376dc536028f3463452 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/ernie/modeling_ernie.py @@ -0,0 +1,1825 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ERNIE model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_ernie import ErnieConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "nghuyong/ernie-1.0-base-zh" +_CONFIG_FOR_DOC = "ErnieConfig" + + +class ErnieEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.use_task_id = config.use_task_id + if config.use_task_id: + self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + task_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + # add `task_type_id` for ERNIE model + if self.use_task_id: + if task_type_ids is None: + task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + task_type_embeddings = self.task_type_embeddings(task_type_ids) + embeddings += task_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Ernie +class ErnieSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in ErnieModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Ernie +class ErnieSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +ERNIE_SELF_ATTENTION_CLASSES = { + "eager": ErnieSelfAttention, +} + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Ernie,BERT->ERNIE +class ErnieAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = ERNIE_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, position_embedding_type=position_embedding_type + ) + self.output = ErnieSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Ernie +class ErnieIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Ernie +class ErnieOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie +class ErnieLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ErnieAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = ErnieAttention(config, position_embedding_type="absolute") + self.intermediate = ErnieIntermediate(config) + self.output = ErnieOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Ernie +class ErnieEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([ErnieLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Ernie +class ErniePooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Ernie +class ErniePredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Ernie +class ErnieLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = ErniePredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Ernie +class ErnieOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ErnieLMPredictionHead(config) + + def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +# Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->Ernie +class ErnieOnlyNSPHead(nn.Module): + def __init__(self, config): + super().__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->Ernie +class ErniePreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = ErnieLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class ErniePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ErnieConfig + base_model_prefix = "ernie" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@dataclass +# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie +class ErnieForPreTrainingOutput(ModelOutput): + """ + Output type of [`ErnieForPreTraining`]. + + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the next sequence prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +ERNIE_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`ErnieConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ERNIE_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + task_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Task type embedding is a special embedding to represent the characteristic of different tasks, such as + word-aware pre-training task, structure-aware pre-training task and semantic-aware pre-training task. We + assign a `task_type_id` to each task and the `task_type_id` is in the range `[0, + config.task_type_vocab_size-1] + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Ernie Model transformer outputting raw hidden-states without any specific head on top.", + ERNIE_START_DOCSTRING, +) +class ErnieModel(ErniePreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Ernie + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = ErnieEmbeddings(config) + self.encoder = ErnieEncoder(config) + + self.pooler = ErniePooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertModel.get_input_embeddings + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + # Copied from transformers.models.bert.modeling_bert.BertModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + Ernie Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next + sentence prediction (classification)` head. + """, + ERNIE_START_DOCSTRING, +) +class ErnieForPreTraining(ErniePreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + self.cls = ErniePreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + next_sentence_label: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], ErnieForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + kwargs (`Dict[str, any]`, *optional*, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ErnieForPreTraining + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh") + >>> model = ErnieForPreTraining.from_pretrained("nghuyong/ernie-1.0-base-zh") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.prediction_logits + >>> seq_relationship_logits = outputs.seq_relationship_logits + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) + total_loss = masked_lm_loss + next_sentence_loss + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return ErnieForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING +) +class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `ErnieForCausalLM` as a standalone, add `is_decoder=True.`") + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.cls = ErnieOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING) +class ErnieForMaskedLM(ErniePreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `ErnieForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.cls = ErnieOnlyMLMHead(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.get_output_embeddings + def get_output_embeddings(self): + return self.cls.predictions.decoder + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'paris'", + expected_loss=0.88, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + if self.config.pad_token_id is None: + raise ValueError("The PAD token should be defined for generation") + + attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) + dummy_token = torch.full( + (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + @classmethod + def can_generate(cls) -> bool: + """ + Legacy correction: ErnieForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a + `prepare_inputs_for_generation` method. + """ + return False + + +@add_start_docstrings( + """Ernie Model with a `next sentence prediction (classification)` head on top.""", + ERNIE_START_DOCSTRING, +) +class ErnieForNextSentencePrediction(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForNextSentencePrediction.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + self.cls = ErnieOnlyNSPHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring). Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, ErnieForNextSentencePrediction + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("nghuyong/ernie-1.0-base-zh") + >>> model = ErnieForNextSentencePrediction.from_pretrained("nghuyong/ernie-1.0-base-zh") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light." + >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt") + + >>> outputs = model(**encoding, labels=torch.LongTensor([1])) + >>> logits = outputs.logits + >>> assert logits[0, 0] < logits[0, 1] # next sentence was random + ``` + """ + + if "next_sentence_label" in kwargs: + warnings.warn( + "The `next_sentence_label` argument is deprecated and will be removed in a future version, use" + " `labels` instead.", + FutureWarning, + ) + labels = kwargs.pop("next_sentence_label") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + seq_relationship_scores = self.cls(pooled_output) + + next_sentence_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1)) + + if not return_dict: + output = (seq_relationship_scores,) + outputs[2:] + return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output + + return NextSentencePredictorOutput( + loss=next_sentence_loss, + logits=seq_relationship_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Ernie Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ERNIE_START_DOCSTRING, +) +class ErnieForSequenceClassification(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.ernie = ErnieModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Ernie Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ERNIE_START_DOCSTRING, +) +class ErnieForMultipleChoice(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + + self.ernie = ErnieModel(config) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Ernie Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ERNIE_START_DOCSTRING, +) +class ErnieForTokenClassification(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie = ErnieModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Ernie Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ERNIE_START_DOCSTRING, +) +class ErnieForQuestionAnswering(ErniePreTrainedModel): + # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with Bert->Ernie,bert->ernie + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.ernie = ErnieModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + task_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.ernie( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + task_type_ids=task_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "ErnieForCausalLM", + "ErnieForMaskedLM", + "ErnieForMultipleChoice", + "ErnieForNextSentencePrediction", + "ErnieForPreTraining", + "ErnieForQuestionAnswering", + "ErnieForSequenceClassification", + "ErnieForTokenClassification", + "ErnieModel", + "ErniePreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/esm/__init__.py b/docs/transformers/build/lib/transformers/models/esm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8eac54d6ddcbdae2b8ca3771ae5540522f6f29da --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_esm import * + from .modeling_esm import * + from .modeling_esmfold import * + from .modeling_tf_esm import * + from .tokenization_esm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py b/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..ac56bc8d783a5895496520fc62c7bada7fa5f61a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/configuration_esm.py @@ -0,0 +1,365 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ESM model configuration""" + +from dataclasses import asdict, dataclass +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +# TODO Update this + + +class EsmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ESMModel`]. It is used to instantiate a ESM model + according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ESM + [facebook/esm-1b](https://huggingface.co/facebook/esm-1b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*): + Vocabulary size of the ESM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ESMModel`]. + mask_token_id (`int`, *optional*): + The index of the mask token in the vocabulary. This must be included in the config because of the + "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens. + pad_token_id (`int`, *optional*): + The index of the padding token in the vocabulary. This must be included in the config because certain parts + of the ESM code use this instead of the attention mask. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 1026): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query", "rotary"`. + For positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + emb_layer_norm_before (`bool`, *optional*): + Whether to apply layer normalization after embeddings but before the main stem of the network. + token_dropout (`bool`, defaults to `False`): + When this is enabled, masked tokens are treated as if they had been dropped out by input dropout. + + Examples: + + ```python + >>> from transformers import EsmModel, EsmConfig + + >>> # Initializing a ESM facebook/esm-1b style configuration + >>> configuration = EsmConfig(vocab_size=33) + + >>> # Initializing a model from the configuration + >>> model = EsmModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "esm" + + def __init__( + self, + vocab_size=None, + mask_token_id=None, + pad_token_id=None, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1026, + initializer_range=0.02, + layer_norm_eps=1e-12, + position_embedding_type="absolute", + use_cache=True, + emb_layer_norm_before=None, + token_dropout=False, + is_folding_model=False, + esmfold_config=None, + vocab_list=None, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.emb_layer_norm_before = emb_layer_norm_before + self.token_dropout = token_dropout + self.is_folding_model = is_folding_model + if is_folding_model: + if esmfold_config is None: + logger.info("No esmfold_config supplied for folding model, using default values.") + esmfold_config = EsmFoldConfig() + elif isinstance(esmfold_config, dict): + esmfold_config = EsmFoldConfig(**esmfold_config) + self.esmfold_config = esmfold_config + if vocab_list is None: + logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!") + self.vocab_list = get_default_vocab_list() + else: + self.vocab_list = vocab_list + else: + self.esmfold_config = None + self.vocab_list = None + if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False): + raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = super().to_dict() + if isinstance(self.esmfold_config, EsmFoldConfig): + output["esmfold_config"] = self.esmfold_config.to_dict() + return output + + +@dataclass +class EsmFoldConfig: + esm_type: Optional[str] = None + fp16_esm: bool = True + use_esm_attn_map: bool = False + esm_ablate_pairwise: bool = False + esm_ablate_sequence: bool = False + esm_input_dropout: float = 0 + + embed_aa: bool = True + bypass_lm: bool = False + + lddt_head_hid_dim: int = 128 + trunk: "TrunkConfig" = None + + def __post_init__(self): + if self.trunk is None: + self.trunk = TrunkConfig() + elif isinstance(self.trunk, dict): + self.trunk = TrunkConfig(**self.trunk) + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = asdict(self) + output["trunk"] = self.trunk.to_dict() + return output + + +@dataclass +class TrunkConfig: + num_blocks: int = 48 + sequence_state_dim: int = 1024 + pairwise_state_dim: int = 128 + sequence_head_width: int = 32 + pairwise_head_width: int = 32 + position_bins: int = 32 + dropout: float = 0 + layer_drop: float = 0 + cpu_grad_checkpoint: bool = False + max_recycles: int = 4 + chunk_size: Optional[int] = 128 + structure_module: "StructureModuleConfig" = None + + def __post_init__(self): + if self.structure_module is None: + self.structure_module = StructureModuleConfig() + elif isinstance(self.structure_module, dict): + self.structure_module = StructureModuleConfig(**self.structure_module) + + if self.max_recycles <= 0: + raise ValueError(f"`max_recycles` should be positive, got {self.max_recycles}.") + if self.sequence_state_dim % self.sequence_state_dim != 0: + raise ValueError( + "`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got" + f" {self.sequence_state_dim} and {self.sequence_state_dim}." + ) + if self.pairwise_state_dim % self.pairwise_state_dim != 0: + raise ValueError( + "`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got" + f" {self.pairwise_state_dim} and {self.pairwise_state_dim}." + ) + + sequence_num_heads = self.sequence_state_dim // self.sequence_head_width + pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width + + if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width: + raise ValueError( + "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got" + f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}." + ) + if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width: + raise ValueError( + "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got" + f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}." + ) + if self.pairwise_state_dim % 2 != 0: + raise ValueError(f"`pairwise_state_dim` should be even, got {self.pairwise_state_dim}.") + + if self.dropout >= 0.4: + raise ValueError(f"`dropout` should not be greater than 0.4, got {self.dropout}.") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = asdict(self) + output["structure_module"] = self.structure_module.to_dict() + return output + + +@dataclass +class StructureModuleConfig: + """ + Args: + sequence_dim: + Single representation channel dimension + pairwise_dim: + Pair representation channel dimension + ipa_dim: + IPA hidden channel dimension + resnet_dim: + Angle resnet (Alg. 23 lines 11-14) hidden channel dimension + num_heads_ipa: + Number of IPA heads + num_qk_points: + Number of query/key points to generate during IPA + num_v_points: + Number of value points to generate during IPA + dropout_rate: + Dropout rate used throughout the layer + num_blocks: + Number of structure module blocks + num_transition_layers: + Number of layers in the single representation transition (Alg. 23 lines 8-9) + num_resnet_blocks: + Number of blocks in the angle resnet + num_angles: + Number of angles to generate in the angle resnet + trans_scale_factor: + Scale of single representation transition hidden dimension + epsilon: + Small number used in angle resnet normalization + inf: + Large number used for attention masking + """ + + sequence_dim: int = 384 + pairwise_dim: int = 128 + ipa_dim: int = 16 + resnet_dim: int = 128 + num_heads_ipa: int = 12 + num_qk_points: int = 4 + num_v_points: int = 8 + dropout_rate: float = 0.1 + num_blocks: int = 8 + num_transition_layers: int = 1 + num_resnet_blocks: int = 2 + num_angles: int = 7 + trans_scale_factor: int = 10 + epsilon: float = 1e-8 + inf: float = 1e5 + + def to_dict(self): + return asdict(self) + + +def get_default_vocab_list(): + return ( + "", + "", + "", + "", + "L", + "A", + "G", + "V", + "S", + "E", + "R", + "T", + "I", + "D", + "P", + "K", + "Q", + "N", + "F", + "Y", + "M", + "H", + "W", + "C", + "X", + "B", + "U", + "Z", + "O", + ".", + "-", + "", + "", + ) + + +__all__ = ["EsmConfig"] diff --git a/docs/transformers/build/lib/transformers/models/esm/convert_esm.py b/docs/transformers/build/lib/transformers/models/esm/convert_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..020dd4e576639230565355d82e74ad6313f875b7 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/convert_esm.py @@ -0,0 +1,399 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert ESM checkpoint.""" + +import argparse +import pathlib +from pathlib import Path +from tempfile import TemporaryDirectory + +import esm as esm_module +import torch +from esm.esmfold.v1.misc import batch_encode_sequences as esmfold_encode_sequences +from esm.esmfold.v1.pretrained import esmfold_v1 + +from transformers.models.esm.configuration_esm import EsmConfig, EsmFoldConfig +from transformers.models.esm.modeling_esm import ( + EsmForMaskedLM, + EsmForSequenceClassification, + EsmIntermediate, + EsmLayer, + EsmOutput, + EsmSelfAttention, + EsmSelfOutput, +) +from transformers.models.esm.modeling_esmfold import EsmForProteinFolding +from transformers.models.esm.tokenization_esm import EsmTokenizer +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SAMPLE_DATA = [ + ( + "protein1", + "MNGTEGPNFYVPFSNATGVVRSPFEYPQYYLAEPWQFSMLAAYMFLLIVLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVLGGFTSTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLAGWSRYIPEGLQCSCGIDYYTLKPEVNNESFVIYMFVVHFTIPMIIIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWVPYASVAFYIFTHQGSNFGPIFMTIPAFFAKSAAIYNPVIYIMMNKQFRNCMLTTICCGKNPLGDDEASATVSKTETSQVAPA", + ), + ("protein2", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"), + ("protein3", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLAGG"), + ("protein4", "MKTVRQERLKSIRILERSKEPVSGAQLAEELSSRQVIVQDIAYLRSLGYNVATPRGYVLA"), +] + +MODEL_MAPPING = { + "esm1b_t33_650M_UR50S": esm_module.pretrained.esm1b_t33_650M_UR50S, + "esm1v_t33_650M_UR90S_1": esm_module.pretrained.esm1v_t33_650M_UR90S_1, + "esm1v_t33_650M_UR90S_2": esm_module.pretrained.esm1v_t33_650M_UR90S_2, + "esm1v_t33_650M_UR90S_3": esm_module.pretrained.esm1v_t33_650M_UR90S_3, + "esm1v_t33_650M_UR90S_4": esm_module.pretrained.esm1v_t33_650M_UR90S_4, + "esm1v_t33_650M_UR90S_5": esm_module.pretrained.esm1v_t33_650M_UR90S_5, + "esm2_t48_15B_UR50D": esm_module.pretrained.esm2_t48_15B_UR50D, + "esm2_t36_3B_UR50D": esm_module.pretrained.esm2_t36_3B_UR50D, + "esm2_t33_650M_UR50D": esm_module.pretrained.esm2_t33_650M_UR50D, + "esm2_t30_150M_UR50D": esm_module.pretrained.esm2_t30_150M_UR50D, + "esm2_t12_35M_UR50D": esm_module.pretrained.esm2_t12_35M_UR50D, + "esm2_t6_8M_UR50D": esm_module.pretrained.esm2_t6_8M_UR50D, + "esmfold_v1": esmfold_v1, +} + +restypes = list("ARNDCQEGHILKMFPSTWYV") + +restypes_with_x = restypes + ["X"] +restypes_with_extras = restypes_with_x + ["", "", "", "", ""] + + +def get_esmfold_tokenizer(): + with TemporaryDirectory() as tempdir: + vocab = "\n".join(restypes_with_extras) + vocab_file = Path(tempdir) / "vocab.txt" + vocab_file.write_text(vocab) + hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file)) + hf_tokenizer.pad_token_id = 0 # Overlaps with 'A' but that seems to be what they want + return hf_tokenizer + + +def transfer_and_check_weights(original_module, our_module): + status = our_module.load_state_dict(original_module.state_dict()) + if status.missing_keys: + raise ValueError(f"Missing keys: {status.missing_keys}") + if status.unexpected_keys: + raise ValueError(f"Unexpected keys: {status.unexpected_keys}") + + +def convert_esm_checkpoint_to_pytorch( + model: str, pytorch_dump_folder_path: str, classification_head: bool, push_to_repo: str, auth_token: str +): + """ + Copy/paste/tweak esm's weights to our BERT structure. + """ + if model.startswith("esmfold"): + esm = MODEL_MAPPING[model]() + else: + esm, alphabet = MODEL_MAPPING[model]() + esm.eval() # disable dropout + + if model.startswith("esmfold"): + embed_dim = esm.esm.embed_dim + num_layers = esm.esm.num_layers + num_attention_heads = esm.esm.attention_heads + intermediate_size = 4 * embed_dim + token_dropout = esm.esm.token_dropout + emb_layer_norm_before = False # This code path does not exist in ESM-2 + position_embedding_type = "rotary" + is_folding_model = True + esmfold_config = EsmFoldConfig() + for key, val in esm.cfg.items(): + if hasattr(esmfold_config, key) and key != "trunk": + setattr(esmfold_config, key, val) + for key, val in esm.cfg.trunk.items(): + if hasattr(esmfold_config.trunk, key) and key != "structure_module": + setattr(esmfold_config.trunk, key, val) + for key, val in esm.cfg.trunk.structure_module.items(): + if hasattr(esmfold_config.trunk.structure_module, key): + setattr(esmfold_config.trunk.structure_module, key, val) + elif hasattr(esm, "args"): + # Indicates an ESM-1b or ESM-1v model + embed_dim = esm.args.embed_dim + num_layers = esm.args.layers + num_attention_heads = esm.args.attention_heads + intermediate_size = esm.args.ffn_embed_dim + token_dropout = esm.args.token_dropout + emb_layer_norm_before = True if esm.emb_layer_norm_before else False + position_embedding_type = "absolute" + is_folding_model = False + esmfold_config = None + else: + # Indicates an ESM-2 model + embed_dim = esm.embed_dim + num_layers = esm.num_layers + num_attention_heads = esm.attention_heads + intermediate_size = 4 * embed_dim # This is hardcoded in ESM-2 + token_dropout = esm.token_dropout + emb_layer_norm_before = False # This code path does not exist in ESM-2 + position_embedding_type = "rotary" + is_folding_model = False + esmfold_config = None + + if is_folding_model: + alphabet = esm.esm.alphabet + vocab_list = tuple(alphabet.all_toks) + mask_token_id = alphabet.mask_idx + pad_token_id = alphabet.padding_idx + + if is_folding_model: + original_esm_model = esm.esm + else: + original_esm_model = esm + + config = EsmConfig( + vocab_size=original_esm_model.embed_tokens.num_embeddings, + mask_token_id=mask_token_id, + hidden_size=embed_dim, + num_hidden_layers=num_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + max_position_embeddings=1026, + layer_norm_eps=1e-5, # PyTorch default used in fairseq + attention_probs_dropout_prob=0.0, + hidden_dropout_prob=0.0, + pad_token_id=pad_token_id, + emb_layer_norm_before=emb_layer_norm_before, + token_dropout=token_dropout, + position_embedding_type=position_embedding_type, + is_folding_model=is_folding_model, + esmfold_config=esmfold_config, + vocab_list=vocab_list, + ) + if classification_head: + config.num_labels = esm.classification_heads["mnli"].out_proj.weight.shape[0] + print("Our ESM config:", config) + + if model.startswith("esmfold"): + model_class = EsmForProteinFolding + elif classification_head: + model_class = EsmForSequenceClassification + else: + model_class = EsmForMaskedLM + model = model_class(config) + model.eval() + + # Now let's copy all the weights. + # Embeddings + model.esm.embeddings.word_embeddings.weight = original_esm_model.embed_tokens.weight + if position_embedding_type == "absolute": + model.esm.embeddings.position_embeddings.weight = original_esm_model.embed_positions.weight + + if config.emb_layer_norm_before: + model.esm.embeddings.layer_norm.weight = original_esm_model.emb_layer_norm_before.weight + model.esm.embeddings.layer_norm.bias = original_esm_model.emb_layer_norm_before.bias + + model.esm.encoder.emb_layer_norm_after.weight = original_esm_model.emb_layer_norm_after.weight + model.esm.encoder.emb_layer_norm_after.bias = original_esm_model.emb_layer_norm_after.bias + + for i in range(config.num_hidden_layers): + # Encoder: start of layer + layer: EsmLayer = model.esm.encoder.layer[i] + # esm_layer: TransformerSentenceEncoderLayer = original_esm_model.layers[i] + esm_layer = original_esm_model.layers[i] + + # self attention + self_attn: EsmSelfAttention = layer.attention.self + assert ( + esm_layer.self_attn.k_proj.weight.data.shape + == esm_layer.self_attn.q_proj.weight.data.shape + == esm_layer.self_attn.v_proj.weight.data.shape + == torch.Size((config.hidden_size, config.hidden_size)) + ) + + self_attn.query.weight.data = esm_layer.self_attn.q_proj.weight + self_attn.query.bias.data = esm_layer.self_attn.q_proj.bias + self_attn.key.weight.data = esm_layer.self_attn.k_proj.weight + self_attn.key.bias.data = esm_layer.self_attn.k_proj.bias + self_attn.value.weight.data = esm_layer.self_attn.v_proj.weight + self_attn.value.bias.data = esm_layer.self_attn.v_proj.bias + + if getattr(esm_layer.self_attn, "rot_emb", None) is not None: + # Matt: Although inv_freq is not a trainable weight, it is computed at model init and cached. + # During the training of ESM-2 the model was converted to float16 precision, which also converts + # the inv_freq tensor, and the loss of precision remains even if the model is loaded later as float32. + # If we recompute inv_freq without this loss of precision then we will get subtly different rotary + # embeddings, which are enough to cause significant discrepancies in model outputs. To avoid this, + # we make sure the new model copies the data from the old inv_freq. + self_attn.rotary_embeddings.inv_freq.data = esm_layer.self_attn.rot_emb.inv_freq + + # LayerNorm changes for pre-activation + layer.attention.LayerNorm.weight = esm_layer.self_attn_layer_norm.weight + layer.attention.LayerNorm.bias = esm_layer.self_attn_layer_norm.bias + layer.LayerNorm.weight = esm_layer.final_layer_norm.weight + layer.LayerNorm.bias = esm_layer.final_layer_norm.bias + + # self-attention output + self_output: EsmSelfOutput = layer.attention.output + assert self_output.dense.weight.shape == esm_layer.self_attn.out_proj.weight.shape + self_output.dense.weight = esm_layer.self_attn.out_proj.weight + self_output.dense.bias = esm_layer.self_attn.out_proj.bias + + # intermediate + intermediate: EsmIntermediate = layer.intermediate + assert intermediate.dense.weight.shape == esm_layer.fc1.weight.shape + intermediate.dense.weight = esm_layer.fc1.weight + intermediate.dense.bias = esm_layer.fc1.bias + + # output + bert_output: EsmOutput = layer.output + assert bert_output.dense.weight.shape == esm_layer.fc2.weight.shape + bert_output.dense.weight = esm_layer.fc2.weight + bert_output.dense.bias = esm_layer.fc2.bias + # end of layer + + if is_folding_model: + model.esm_s_combine.data = esm.esm_s_combine.data + model.af2_to_esm.data = esm.af2_to_esm.data + transfer_and_check_weights(esm.embedding, model.embedding) + transfer_and_check_weights(esm.esm_s_mlp, model.esm_s_mlp) + transfer_and_check_weights(esm.trunk, model.trunk) + transfer_and_check_weights(esm.distogram_head, model.distogram_head) + transfer_and_check_weights(esm.ptm_head, model.ptm_head) + transfer_and_check_weights(esm.lm_head, model.lm_head) + transfer_and_check_weights(esm.lddt_head, model.lddt_head) + + elif classification_head: + model.classifier.dense.weight = esm.esm.classification_heads["mnli"].dense.weight + model.classifier.dense.bias = esm.classification_heads["mnli"].dense.bias + model.classifier.out_proj.weight = esm.classification_heads["mnli"].out_proj.weight + model.classifier.out_proj.bias = esm.classification_heads["mnli"].out_proj.bias + else: + # LM Head + model.lm_head.dense.weight = esm.lm_head.dense.weight + model.lm_head.dense.bias = esm.lm_head.dense.bias + model.lm_head.layer_norm.weight = esm.lm_head.layer_norm.weight + model.lm_head.layer_norm.bias = esm.lm_head.layer_norm.bias + model.lm_head.decoder.weight = esm.lm_head.weight + model.lm_head.bias = esm.lm_head.bias + + # Contact prediction head + transfer_and_check_weights(esm.contact_head, model.esm.contact_head) + + # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4) + if is_folding_model: + # Folding models aren't trained on masked inputs and don't like mask tokens. + sample_data = SAMPLE_DATA[:2] + else: + sample_data = SAMPLE_DATA + + if is_folding_model: + hf_tokenizer = get_esmfold_tokenizer() + hf_tokens = hf_tokenizer( + [row[1] for row in sample_data], return_tensors="pt", padding=True, add_special_tokens=False + ) + esmfold_aas, esmfold_mask, _, _, _ = esmfold_encode_sequences([row[1] for row in sample_data]) + success = torch.all(hf_tokens["input_ids"] == esmfold_aas) and torch.all( + hf_tokens["attention_mask"] == esmfold_mask + ) + else: + # Let's check that we get the same results. + batch_converter = alphabet.get_batch_converter() + batch_labels, batch_strs, batch_tokens = batch_converter(sample_data) + # Prepare tokenizer and make sure it matches + with TemporaryDirectory() as tempdir: + vocab = "\n".join(alphabet.all_toks) + vocab_file = Path(tempdir) / "vocab.txt" + vocab_file.write_text(vocab) + hf_tokenizer = EsmTokenizer(vocab_file=str(vocab_file)) + + hf_tokens = hf_tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True) + success = torch.all(hf_tokens["input_ids"] == batch_tokens) + + print("Do both models tokenizers output the same tokens?", "🔥" if success else "💩") + if not success: + raise Exception("Tokenization does not match!") + + with torch.no_grad(): + if is_folding_model: + # Let's test the model in parts + # ESMFold always converts the ESM stem to float16, which requires float16 ops + # that don't exist on CPU. Therefore, to test it we need to run it on GPU. However, + # ESMFold is what we in the community call a "big boy" and so we desperately avoid putting both the + # original and the converted model on the GPU at the same time. + their_output = esm.cuda().infer([row[1] for row in sample_data]) + our_output = model.cuda()( + input_ids=hf_tokens["input_ids"].cuda(), attention_mask=hf_tokens["attention_mask"].cuda() + ) + else: + our_output = model(**hf_tokens, output_hidden_states=True) + our_output = our_output["logits"] + if classification_head: + their_output = esm.model.classification_heads["mnli"](esm.extract_features(batch_tokens)) + else: + their_output = esm(hf_tokens["input_ids"], repr_layers=list(range(999))) + their_output = their_output["logits"] + + if is_folding_model: + max_absolute_diff = torch.max(torch.abs(our_output["positions"] - their_output["positions"])).item() + success = torch.allclose(our_output["positions"], their_output["positions"], atol=1e-5) + else: + max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() + success = torch.allclose(our_output, their_output, atol=1e-5) + + print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5 + print("Do both models output the same tensors?", "🔥" if success else "💩") + + if not success: + raise Exception("Something went wRoNg") + + if not is_folding_model: + # Let's check contact prediction too + our_output = model.predict_contacts(hf_tokens["input_ids"], hf_tokens["attention_mask"]) + their_output = esm.predict_contacts(hf_tokens["input_ids"]) + max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item() + success = torch.allclose(our_output, their_output, atol=1e-5) + + print("Contact prediction testing:") + print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-5 + print("Do both models output the same tensors?", "🔥" if success else "💩") + + if not success: + raise Exception("Something went wRoNg") + + pathlib.Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True) + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + del esm # Free up some memory before continuing + + print(f"Saving tokenizer to {pytorch_dump_folder_path}") + hf_tokenizer.save_pretrained(pytorch_dump_folder_path) + + if push_to_repo: + model.push_to_hub(repo_id=push_to_repo, token_token=auth_token) + hf_tokenizer.push_to_hub(repo_id=push_to_repo, token_token=auth_token) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model." + ) + parser.add_argument( + "--classification_head", action="store_true", help="Whether to convert a final classification head." + ) + parser.add_argument("--model", default=None, type=str, required=True, help="Name of model to convert.") + parser.add_argument("--push_to_repo", type=str, help="Repo to upload to (including username!).") + parser.add_argument("--auth_token", type=str, help="HuggingFace auth token.") + args = parser.parse_args() + convert_esm_checkpoint_to_pytorch( + args.model, args.pytorch_dump_folder_path, args.classification_head, args.push_to_repo, args.auth_token + ) diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py b/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..6f90d8d052f9bc9a3f0c5cdeb74aad823b44bf21 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/modeling_esm.py @@ -0,0 +1,1273 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ESM model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import logging +from .configuration_esm import EsmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D" +_CONFIG_FOR_DOC = "EsmConfig" + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +def gelu(x): + """ + This is the gelu implementation from the original ESM repo. Using F.gelu yields subtly wrong results. + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose(-1, -2) + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized + + +class RotaryEmbedding(torch.nn.Module): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + inv_freq = inv_freq + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=2): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :] + self._sin_cached = emb.sin()[None, None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), + ) + + +class EsmContactPredictionHead(nn.Module): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + bias=True, + eos_idx: int = 2, + ): + super().__init__() + self.in_features = in_features + self.eos_idx = eos_idx + self.regression = nn.Linear(in_features, 1, bias) + self.activation = nn.Sigmoid() + + def forward(self, tokens, attentions): + # remove eos token attentions + eos_mask = tokens.ne(self.eos_idx).to(attentions) + eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) + attentions = attentions * eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + # remove cls token attentions + attentions = attentions[..., 1:, 1:] + batch_size, layers, heads, seqlen, _ = attentions.size() + attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) + + # features: batch x channels x tokens x tokens (symmetric) + attentions = attentions.to( + self.regression.weight.device + ) # attentions always float32, may need to convert to float16 + attentions = average_product_correct(symmetrize(attentions)) + attentions = attentions.permute(0, 2, 3, 1) + return self.activation(self.regression(attentions).squeeze(3)) + + +class EsmEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + + if config.emb_layer_norm_before: + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.layer_norm = None + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + + def forward( + self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout: + embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = attention_mask.sum(-1) + mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths + embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to( + embeddings.dtype + ) + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +class EsmSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs.to(value_layer.dtype), value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class EsmSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EsmAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = EsmSelfAttention(config) + self.output = EsmSelfOutput(config) + self.pruned_heads = set() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class EsmIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = gelu(hidden_states) + return hidden_states + + +class EsmOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states + input_tensor + return hidden_states + + +class EsmLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = EsmAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = EsmAttention(config) + self.intermediate = EsmIntermediate(config) + self.output = EsmOutput(config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = self.feed_forward_chunk(attention_output) + + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs + + def feed_forward_chunk(self, attention_output): + attention_output_ln = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(attention_output_ln) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class EsmEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)]) + self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." + ) + use_cache = False + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache = next_decoder_cache + (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class EsmPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class EsmPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EsmConfig + base_model_prefix = "esm" + supports_gradient_checkpointing = True + _no_split_modules = ["EsmLayer", "EsmFoldTriangularSelfAttentionBlock", "EsmEmbeddings"] + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, EsmLMHead): + module.bias.data.zero_() + + +ESM_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`EsmConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ESM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class EsmModel(EsmPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = EsmEmbeddings(config) + self.encoder = EsmEncoder(config) + + self.pooler = EsmPooler(config) if add_pooling_layer else None + + self.contact_head = EsmContactPredictionHead( + in_features=config.num_hidden_layers * config.num_attention_heads, bias=True + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions + attns = torch.stack(attns, dim=1) # Matches the original model layout + # In the original model, attentions for padding tokens are completely zeroed out. + # This makes no difference most of the time because the other tokens won't attend to them, + # but it does for the contact prediction task, which takes attentions as input, + # so we have to mimic that here. + attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) + attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) + return self.contact_head(tokens, attns) + + +@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) +class EsmForMaskedLM(EsmPreTrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = EsmModel(config, add_pooling_layer=False) + self.lm_head = EsmLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, *optional*, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(prediction_scores.device) + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask=attention_mask) + + +class EsmLMHead(nn.Module): + """ESM Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + self.bias + return x + + +@add_start_docstrings( + """ + ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ESM_START_DOCSTRING, +) +class EsmForSequenceClassification(EsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.esm = EsmModel(config, add_pooling_layer=False) + self.classifier = EsmClassificationHead(config) + + self.init_weights() + + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ESM_START_DOCSTRING, +) +class EsmForTokenClassification(EsmPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = EsmModel(config, add_pooling_layer=False) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + + labels = labels.to(logits.device) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class EsmClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx + + +__all__ = [ + "EsmForMaskedLM", + "EsmForSequenceClassification", + "EsmForTokenClassification", + "EsmModel", + "EsmPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py b/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py new file mode 100644 index 0000000000000000000000000000000000000000..645c9d16a5c5dcd012cf3d91ca784267ede9c654 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/modeling_esmfold.py @@ -0,0 +1,2325 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import sys +from dataclasses import dataclass +from functools import partial +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import LayerNorm + +from ...integrations.deepspeed import is_deepspeed_available +from ...modeling_outputs import ModelOutput +from ...utils import ( + ContextManagers, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_scipy_available, + logging, + replace_return_docstrings, +) +from .configuration_esm import EsmConfig +from .modeling_esm import ESM_START_DOCSTRING, EsmModel, EsmPreTrainedModel +from .openfold_utils import ( + OFProtein, + Rigid, + Rotation, + atom14_to_atom37, + chunk_layer, + compute_predicted_aligned_error, + compute_tm, + frames_and_literature_positions_to_atom14_pos, + make_atom14_masks, + residue_constants, + to_pdb, + torsion_angles_to_frames, +) + + +logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "facebook/esmfold_v1" +_CONFIG_FOR_DOC = "EsmConfig" + + +@dataclass +class EsmForProteinFoldingOutput(ModelOutput): + """ + Output type of [`EsmForProteinFoldingOutput`]. + + Args: + frames (`torch.FloatTensor`): + Output frames. + sidechain_frames (`torch.FloatTensor`): + Output sidechain frames. + unnormalized_angles (`torch.FloatTensor`): + Predicted unnormalized backbone and side chain torsion angles. + angles (`torch.FloatTensor`): + Predicted backbone and side chain torsion angles. + positions (`torch.FloatTensor`): + Predicted positions of the backbone and side chain atoms. + states (`torch.FloatTensor`): + Hidden states from the protein folding trunk. + s_s (`torch.FloatTensor`): + Per-residue embeddings derived by concatenating the hidden states of each layer of the ESM-2 LM stem. + s_z (`torch.FloatTensor`): + Pairwise residue embeddings. + distogram_logits (`torch.FloatTensor`): + Input logits to the distogram used to compute residue distances. + lm_logits (`torch.FloatTensor`): + Logits output by the ESM-2 protein language model stem. + aatype (`torch.FloatTensor`): + Input amino acids (AlphaFold2 indices). + atom14_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom14 representation. + residx_atom14_to_atom37 (`torch.FloatTensor`): + Mapping between atoms in the atom14 and atom37 representations. + residx_atom37_to_atom14 (`torch.FloatTensor`): + Mapping between atoms in the atom37 and atom14 representations. + atom37_atom_exists (`torch.FloatTensor`): + Whether each atom exists in the atom37 representation. + residue_index (`torch.FloatTensor`): + The index of each residue in the protein chain. Unless internal padding tokens are used, this will just be + a sequence of integers from 0 to `sequence_length`. + lddt_head (`torch.FloatTensor`): + Raw outputs from the lddt head used to compute plddt. + plddt (`torch.FloatTensor`): + Per-residue confidence scores. Regions of low confidence may indicate areas where the model's prediction is + uncertain, or where the protein structure is disordered. + ptm_logits (`torch.FloatTensor`): + Raw logits used for computing ptm. + ptm (`torch.FloatTensor`): + TM-score output representing the model's high-level confidence in the overall structure. + aligned_confidence_probs (`torch.FloatTensor`): + Per-residue confidence scores for the aligned structure. + predicted_aligned_error (`torch.FloatTensor`): + Predicted error between the model's prediction and the ground truth. + max_predicted_aligned_error (`torch.FloatTensor`): + Per-sample maximum predicted error. + """ + + frames: Optional[torch.FloatTensor] = None + sidechain_frames: Optional[torch.FloatTensor] = None + unnormalized_angles: Optional[torch.FloatTensor] = None + angles: Optional[torch.FloatTensor] = None + positions: Optional[torch.FloatTensor] = None + states: Optional[torch.FloatTensor] = None + s_s: Optional[torch.FloatTensor] = None + s_z: Optional[torch.FloatTensor] = None + distogram_logits: Optional[torch.FloatTensor] = None + lm_logits: Optional[torch.FloatTensor] = None + aatype: Optional[torch.FloatTensor] = None + atom14_atom_exists: Optional[torch.FloatTensor] = None + residx_atom14_to_atom37: Optional[torch.FloatTensor] = None + residx_atom37_to_atom14: Optional[torch.FloatTensor] = None + atom37_atom_exists: Optional[torch.FloatTensor] = None + residue_index: Optional[torch.FloatTensor] = None + lddt_head: Optional[torch.FloatTensor] = None + plddt: Optional[torch.FloatTensor] = None + ptm_logits: Optional[torch.FloatTensor] = None + ptm: Optional[torch.FloatTensor] = None + aligned_confidence_probs: Optional[torch.FloatTensor] = None + predicted_aligned_error: Optional[torch.FloatTensor] = None + max_predicted_aligned_error: Optional[torch.FloatTensor] = None + + +ESMFOLD_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + masking_pattern (`torch.LongTensor` of shape `({0})`, *optional*): + Locations of tokens to mask during training as a form of regularization. Mask values selected in `[0, 1]`. + num_recycles (`int`, *optional*, defaults to `None`): + Number of times to recycle the input sequence. If `None`, defaults to `config.num_recycles`. "Recycling" + consists of passing the output of the folding trunk back in as input to the trunk. During training, the + number of recycles should vary with each batch, to ensure that the model learns to output valid predictions + after each recycle. During inference, num_recycles should be set to the highest value that the model was + trained with for maximum accuracy. Accordingly, when this value is set to `None`, config.max_recycles is + used. +""" + + +def is_fp16_enabled(): + # Autocast world + fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 + fp16_enabled = fp16_enabled and torch.is_autocast_enabled() + + return fp16_enabled + + +def is_deepspeed_initialized(): + if is_deepspeed_available(): + return False + else: + try: + import deepspeed + + # This is not available in all DeepSpeed versions. + return deepspeed.utils.is_initialized() + except Exception: + return False + + +def collate_dense_tensors(samples: List[torch.Tensor], pad_v: float = 0) -> torch.Tensor: + """ + Takes a list of tensors with the following dimensions: + [(d_11, ..., d_1K), + (d_21, ..., d_2K), ..., (d_N1, ..., d_NK)] + and stack + pads them into a single tensor of: + (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) + """ + if len(samples) == 0: + return torch.Tensor() + if len({x.dim() for x in samples}) != 1: + raise RuntimeError(f"Samples has varying dimensions: {[x.dim() for x in samples]}") + (device,) = tuple({x.device for x in samples}) # assumes all on same device + max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] + result = torch.empty(len(samples), *max_shape, dtype=samples[0].dtype, device=device) + result.fill_(pad_v) + for i in range(len(samples)): + result_i = result[i] + t = samples[i] + result_i[tuple(slice(0, k) for k in t.shape)] = t + return result + + +def flatten_final_dims(t: torch.Tensor, no_dims: int): + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]): + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if isinstance(v, dict): + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def trunc_normal_init_(weights, scale=1.0, fan="fan_in"): + shape = weights.shape + scale = scale / max(1, shape[1]) + + if not is_scipy_available(): + logger.warning( + "This init requires scipy, but scipy was not found, default to an approximation that might not be" + " equivalent." + ) + std = math.sqrt(scale) + torch.nn.init.normal_(weights, std=std).clamp(min=0.0, max=2.0 * std) + + else: + from scipy.stats import truncnorm + + std = math.sqrt(scale) / truncnorm.std(a=-2, b=2, loc=0, scale=1) + samples = truncnorm.rvs(a=-2, b=2, loc=0, scale=std, size=weights.numel()) + samples = np.reshape(samples, shape) + weights.copy_(torch.tensor(samples, device=weights.device)) + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +class EsmFoldLinear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. + + Implements the initializers in 1.11.4, plus some additional ones found in the code. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + init: str = "default", + init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None, + ): + """ + Args: + in_dim: + The final dimension of inputs to the layer + out_dim: + The final dimension of layer outputs + bias: + Whether to learn an additive bias. True by default + init: + The initializer to use. Choose from: + + "default": LeCun fan-in truncated normal initialization "relu": He initialization w/ truncated normal + distribution "glorot": Fan-average Glorot uniform initialization "gating": Weights=0, Bias=1 "normal": + Normal initialization with std=1/sqrt(fan_in) "final": Weights=0, Bias=0 + + Overridden by init_fn if the latter is not None. + init_fn: + A custom initializer taking weight and bias as inputs. Overrides init if not None. + """ + super().__init__(in_dim, out_dim, bias=bias) + + if bias: + with torch.no_grad(): + self.bias.fill_(0) + self.init = init + self.init_fn = init_fn + + if init not in ["default", "relu", "glorot", "gating", "normal", "final"]: + raise ValueError("Invalid init string.") + + +class EsmFoldLayerNorm(nn.Module): + def __init__(self, c_in, eps=1e-5): + super().__init__() + + self.c_in = (c_in,) + self.eps = eps + + self.weight = nn.Parameter(torch.ones(c_in)) + self.bias = nn.Parameter(torch.zeros(c_in)) + + def forward(self, x): + d = x.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) + else: + out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) + + return out + + +@torch.jit.ignore +def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Softmax, but without automatic casting to fp32 when the input is of type bfloat16 + """ + d = t.dtype + if d is torch.bfloat16 and not is_deepspeed_initialized(): + with torch.cuda.amp.autocast(enabled=False): + s = torch.nn.functional.softmax(t, dim=dim) + else: + s = torch.nn.functional.softmax(t, dim=dim) + + return s + + +class EsmFoldAttention(nn.Module): + """ + Standard multi-head attention using AlphaFold's default layer initialization. Allows multiple bias vectors. + """ + + def __init__( + self, + c_q: int, + c_k: int, + c_v: int, + c_hidden: int, + no_heads: int, + gating: bool = True, + ): + """ + Args: + c_q: + Input dimension of query data + c_k: + Input dimension of key data + c_v: + Input dimension of value data + c_hidden: + Per-head hidden dimension + no_heads: + Number of attention heads + gating: + Whether the output should be gated using query data + """ + super().__init__() + + self.c_q = c_q + self.c_k = c_k + self.c_v = c_v + self.c_hidden = c_hidden + self.no_heads = no_heads + self.gating = gating + + # DISCREPANCY: c_hidden is not the per-head channel dimension, as + # stated in the supplement, but the overall channel dimension. + + self.linear_q = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_k = EsmFoldLinear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_v = EsmFoldLinear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot") + self.linear_o = EsmFoldLinear(self.c_hidden * self.no_heads, self.c_q, init="final") + + self.linear_g = None + if self.gating: + self.linear_g = EsmFoldLinear(self.c_q, self.c_hidden * self.no_heads, init="gating") + + self.sigmoid = nn.Sigmoid() + + def _prep_qkv(self, q_x: torch.Tensor, kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # [*, Q/K/V, H * C_hidden] + q = self.linear_q(q_x) + k = self.linear_k(kv_x) + v = self.linear_v(kv_x) + + # [*, Q/K, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + # [*, H, Q/K, C_hidden] + q = q.transpose(-2, -3) + k = k.transpose(-2, -3) + v = v.transpose(-2, -3) + + q /= math.sqrt(self.c_hidden) + + return q, k, v + + def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: + if self.linear_g is not None: + g = self.sigmoid(self.linear_g(q_x)) + + # [*, Q, H, C_hidden] + g = g.view(g.shape[:-1] + (self.no_heads, -1)) + o = o * g + + # [*, Q, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, Q, C_q] + o = self.linear_o(o) + + return o + + def forward( + self, + q_x: torch.Tensor, + kv_x: torch.Tensor, + biases: Optional[List[torch.Tensor]] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + lma_q_chunk_size: int = 1024, + lma_kv_chunk_size: int = 4096, + use_flash: bool = False, + flash_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + q_x: + [*, Q, C_q] query data + kv_x: + [*, K, C_k] key data + biases: + List of biases that broadcast to [*, H, Q, K] + use_memory_efficient_kernel: + Whether to use a custom memory-efficient attention kernel. This should be the default choice for most. + If none of the "use_<...>" flags are True, a stock PyTorch implementation is used instead + use_lma: + Whether to use low-memory attention (Staats & Rabe 2021). If none of the "use_<...>" flags are True, a + stock PyTorch implementation is used instead + lma_q_chunk_size: + Query chunk size (for LMA) + lma_kv_chunk_size: + Key/Value chunk size (for LMA) + Returns + [*, Q, C_q] attention update + """ + if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None): + raise ValueError("If use_lma is specified, lma_q_chunk_size and lma_kv_chunk_size must be provided") + + if use_flash and biases is not None: + raise ValueError("use_flash is incompatible with the bias option. For masking, use flash_mask instead") + + attn_options = [use_memory_efficient_kernel, use_lma, use_flash] + if sum(attn_options) > 1: + raise ValueError("Choose at most one alternative attention algorithm") + + if biases is None: + biases = [] + + # [*, H, Q/K, C_hidden] + query, key, value = self._prep_qkv(q_x, kv_x) + key = permute_final_dims(key, (1, 0)) + + # [*, H, Q, K] + output = torch.matmul(query, key) + for b in biases: + output += b + output = softmax_no_cast(output, -1) + + # [*, H, Q, C_hidden] + output = torch.matmul(output, value) + output = output.transpose(-2, -3) + output = self._wrap_up(output, q_x) + + return output + + +class EsmFoldTriangleAttention(nn.Module): + def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9): + """ + Args: + c_in: + Input channel dimension + c_hidden: + Overall hidden channel dimension (not per-head) + no_heads: + Number of attention heads + """ + super().__init__() + + self.c_in = c_in + self.c_hidden = c_hidden + self.no_heads = no_heads + self.starting = starting + self.inf = inf + + self.layer_norm = LayerNorm(self.c_in) + + self.linear = EsmFoldLinear(c_in, self.no_heads, bias=False, init="normal") + + self.mha = EsmFoldAttention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + biases: List[torch.Tensor], + chunk_size: int, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + "triangle! triangle!" + mha_inputs = { + "q_x": x, + "kv_x": x, + "biases": biases, + } + + return chunk_layer( + partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma), + mha_inputs, + chunk_size=chunk_size, + no_batch_dims=len(x.shape[:-2]), + _out=x if inplace_safe else None, + ) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + use_memory_efficient_kernel: bool = False, + use_lma: bool = False, + inplace_safe: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + [*, I, J, C_in] input tensor (e.g. the pair representation) + Returns: + [*, I, J, C_in] output tensor + """ + if mask is None: + # [*, I, J] + mask = x.new_ones( + x.shape[:-1], + ) + + if not self.starting: + x = x.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + # [*, I, J, C_in] + x = self.layer_norm(x) + + # [*, I, 1, 1, J] + mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] + + # [*, H, I, J] + triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) + + # [*, 1, H, I, J] + triangle_bias = triangle_bias.unsqueeze(-4) + + biases = [mask_bias, triangle_bias] + + if chunk_size is not None: + x = self._chunk( + x, + biases, + chunk_size, + use_memory_efficient_kernel=use_memory_efficient_kernel, + use_lma=use_lma, + inplace_safe=inplace_safe, + ) + else: + x = self.mha( + q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma + ) + + if not self.starting: + x = x.transpose(-2, -3) + + return x + + +class EsmFoldTriangleMultiplicativeUpdate(nn.Module): + """ + Implements Algorithms 11 and 12. + """ + + def __init__(self, config, _outgoing=True): + super().__init__() + c_hidden = config.pairwise_state_dim + self._outgoing = _outgoing + + self.linear_a_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_a_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_b_p = EsmFoldLinear(c_hidden, c_hidden) + self.linear_b_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_g = EsmFoldLinear(c_hidden, c_hidden, init="gating") + self.linear_z = EsmFoldLinear(c_hidden, c_hidden, init="final") + + self.layer_norm_in = LayerNorm(c_hidden) + self.layer_norm_out = LayerNorm(c_hidden) + + self.sigmoid = nn.Sigmoid() + + def _combine_projections( + self, a: torch.Tensor, b: torch.Tensor, _inplace_chunk_size: Optional[int] = None + ) -> torch.Tensor: + if self._outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = permute_final_dims(b, (2, 1, 0)) + else: + a = permute_final_dims(a, (2, 1, 0)) + b = permute_final_dims(b, (2, 0, 1)) + + if _inplace_chunk_size is not None: + # To be replaced by torch vmap + for i in range(0, a.shape[-3], _inplace_chunk_size): + a_chunk = a[..., i : i + _inplace_chunk_size, :, :] + b_chunk = b[..., i : i + _inplace_chunk_size, :, :] + a[..., i : i + _inplace_chunk_size, :, :] = torch.matmul( + a_chunk, + b_chunk, + ) + + p = a + else: + p = torch.matmul(a, b) + + return permute_final_dims(p, (1, 2, 0)) + + def _inference_forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_chunk_size: Optional[int] = None, + with_add: bool = True, + ): + """ + Args: + z: + A [*, N, N, C_z] pair representation + mask: + A [*, N, N] pair mask + inplace_chunk_size: + Size of chunks used in the main computation. Increase to trade memory for speed. + with_add: + If True, z is overwritten with (z + update). Otherwise, it is overwritten with (update). + Returns: + A reference to the overwritten z + + More memory-efficient, inference-only version of the forward function. Uses in-place operations, fusion of the + addition that happens after this module in the Evoformer, a smidge of recomputation, and a cache of overwritten + values to lower peak memory consumption of this module from 5x the size of the input tensor z to 2.5x its size. + Useful for inference on extremely long sequences. + + It works as follows. We will make reference to variables used in the default forward implementation below. + Naively, triangle multiplication attention requires the manifestation of 5 tensors the size of z: 1) z, the + "square" input tensor, 2) a, the first projection of z, 3) b, the second projection of b, 4) g, a z-sized mask, + and 5) a z-sized tensor for intermediate computations. For large N, this is prohibitively expensive; for + N=4000, for example, z is more than 8GB alone. To avoid this problem, we compute b, g, and all intermediate + tensors in small chunks, noting that the chunks required to compute a chunk of the output depend only on the + tensor a and corresponding vertical and horizontal chunks of z. This suggests an algorithm that loops over + pairs of chunks of z: hereafter "columns" and "rows" of z, even though each "column" and "row" in fact contains + inplace_chunk_size contiguous true columns and rows of z. Writing output chunks to a new tensor would bring + total memory consumption down to 3x the size of z. However, more memory can be saved by writing output chunks + directly to z in-place. WLOG, we choose to write output chunks vertically, overwriting the ith "column" of z at + the end of the ith iteration of the main loop. Despite this overwriting, the ith column is always one column + ahead of previously overwritten columns and can be recovered directly from z. After the first iteration, + however, the ith row of z is always at least partially overwritten. For this reason, we introduce the z-cache, + a tensor one-half the size of z. The z-cache initially contains the left half (2nd and 3rd quadrants) of z. For + 0 < i < N/2, the missing left part of the ith row of z is recovered from this cache at the beginning of the ith + iteration. Once i exceeds n/2, the cache is "reoriented" to encompass the 3rd and 4th quadrants of z instead. + Though the 3rd quadrant of the original z is entirely overwritten at this point, it can be recovered from the + z-cache itself. Thereafter, the ith row of z can be recovered in its entirety from the reoriented z-cache. + After the final iteration, z has been completely overwritten and contains the triangular multiplicative update. + If with_add is True, it instead contains the sum of z and the triangular multiplicative update. In either case, + peak memory consumption is just 2.5x the size of z, disregarding memory used for chunks and other small + variables. + """ + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + def compute_projection_helper(pair, mask, a=True): + if a: + linear_g = self.linear_a_g + linear_p = self.linear_a_p + else: + linear_g = self.linear_b_g + linear_p = self.linear_b_p + + pair = self.layer_norm_in(pair) + p = linear_g(pair) + p.sigmoid_() + p *= linear_p(pair) + p *= mask + p = permute_final_dims(p, (2, 0, 1)) + return p + + def compute_projection(pair, mask, a=True, chunked=True): + need_transpose = self._outgoing ^ a + if not chunked: + p = compute_projection_helper(pair, mask, a) + if need_transpose: + p = p.transpose(-1, -2) + else: + # This computation is chunked so as not to exceed our 2.5x + # budget with a large intermediate tensor + linear_g = self.linear_a_g if a else self.linear_b_g + c = linear_g.bias.shape[-1] + out_shape = pair.shape[:-3] + (c,) + pair.shape[-3:-1] + p = pair.new_zeros(out_shape) + for i in range(0, pair.shape[-3], inplace_chunk_size): + pair_chunk = pair[..., i : i + inplace_chunk_size, :, :] + pair_chunk = compute_projection_helper( + pair[..., i : i + inplace_chunk_size, :, :], + mask[..., i : i + inplace_chunk_size, :, :], + a, + ) + if need_transpose: + pair_chunk = pair_chunk.transpose(-1, -2) + p[..., i : i + inplace_chunk_size] = pair_chunk + else: + p[..., i : i + inplace_chunk_size, :] = pair_chunk + + del pair_chunk + + return p + + # We start by fully manifesting a. In addition to the input, this + # brings total memory consumption to 2x z (disregarding size of chunks) + # [*, N, N, c] + a = compute_projection(z, mask, True, chunked=True) + + if inplace_chunk_size is not None: + n = a.shape[-1] + half_n = n // 2 + n % 2 + row_dim = -3 + col_dim = -2 + b_chunk_dim = row_dim if self._outgoing else col_dim + + def empty_slicer(t): + return [slice(None) for _ in t.shape] + + def slice_tensor(t, start, end, dim): + # Slices start:end from the dim dimension of t + s = empty_slicer(t) + s[dim] = slice(start, end) + return t[s] + + def flip_z_cache_(z_cache, z): + # "Reorient" the z_cache (see below), filling it with quadrants + # 3---recovered from the z_cache---and 4---recovered from z--- + # of the input tensor z. + quadrant_3 = slice_tensor(z_cache, half_n, None, row_dim) + z_cache = z_cache.transpose(row_dim, col_dim) + + # If n is odd, we need to shrink the z_cache by one row + z_cache = z_cache[..., : (n // 2), :, :] + + # Move the 3rd quadrant of z into the + first_half_slicer = empty_slicer(z_cache) + first_half_slicer[col_dim] = slice(0, half_n) + z_cache[first_half_slicer] = quadrant_3 + + # Get the fourth quadrant of z + quadrant_4 = slice_tensor(z, half_n, None, row_dim) + quadrant_4 = slice_tensor(quadrant_4, half_n, None, col_dim) + + # Insert said quadrant into the rotated z-cache + quadrant_3_slicer = empty_slicer(z_cache) + quadrant_3_slicer[col_dim] = slice(half_n, None) + + z_cache[quadrant_3_slicer] = quadrant_4 + + return z_cache + + # Initialize the z cache to the left half of z. + z_cache_shape = list(z.shape) + z_cache_shape[col_dim] = half_n + z_cache = z.new_zeros(z_cache_shape) + z_cache_slicer = empty_slicer(z_cache) + z_cache_slicer[col_dim] = slice(0, half_n) + z_cache.copy_(z[z_cache_slicer]) + z_cache_rotated = False + + # We need to reorient the z-cache at the halfway point, and we + # don't want a single chunk to straddle that point. We contract one + # of the chunks in the middle to address that problem. + i_range = list(range(0, half_n, inplace_chunk_size)) + initial_offsets = [i_2 - i_1 for i_1, i_2 in zip(i_range, i_range[1:] + [half_n])] + after_half = list(range(half_n, n, inplace_chunk_size)) + after_half_offsets = [inplace_chunk_size for _ in after_half] + combined_range_with_offsets = zip(i_range + after_half, initial_offsets + after_half_offsets) + for i, offset in combined_range_with_offsets: + if not z_cache_rotated and i >= half_n: + z_cache = flip_z_cache_(z_cache, z) + z_cache_rotated = True + + z_chunk_b = slice_tensor(z, i, i + offset, b_chunk_dim) + mask_chunk = slice_tensor(mask, i, i + offset, b_chunk_dim) + + z_chunk_b = z_chunk_b.clone() + if b_chunk_dim == col_dim: + z_chunk_b = slice_tensor(z, i, i + offset, col_dim) + else: # b_chunk_dim == row_dim + # In this case, the b-dimension (b_chunk_dim) is partially + # overwritten at the end of each iteration. We need to + # restore the missing component from the z-cache. + if not z_cache_rotated: + z_chunk_slicer = empty_slicer(z_chunk_b) + z_chunk_slicer[col_dim] = slice(0, half_n) + z_chunk_b[z_chunk_slicer] = slice_tensor(z_cache, i, i + offset, row_dim) + else: + z_cache_offset = i - half_n + z_chunk_b = slice_tensor(z_cache, z_cache_offset, z_cache_offset + offset, row_dim) + + b_chunk = compute_projection(z_chunk_b, mask_chunk, a=False, chunked=False) + del z_chunk_b + + x_chunk = torch.matmul(a, b_chunk) + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + + # The g dimension (col_dim) is parallel to and ahead of the + # overwrites in z. We can extract the g chunk normally. + z_chunk_g = slice_tensor(z, i, i + offset, col_dim) + g_chunk = self.linear_g(self.layer_norm_in(z_chunk_g)) + g_chunk.sigmoid_() + del z_chunk_g + + x_chunk *= g_chunk + + # Write the columns into z in-place + z_slicer = empty_slicer(z) + z_slicer[col_dim] = slice(i, i + offset) + if with_add: + z[z_slicer] += x_chunk + else: + z[z_slicer] = x_chunk + else: + b = compute_projection(z, mask, False, False) + x = torch.matmul(a, b) + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.linear_g(z) + g.sigmoid_() + x *= g + if with_add: + z += x + else: + z = x + + return z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + inplace_safe: bool = False, + _add_with_inplace: bool = False, + _inplace_chunk_size: Optional[int] = 256, + ) -> torch.Tensor: + """ + Args: + x: + [*, N_res, N_res, C_z] input tensor + mask: + [*, N_res, N_res] input mask + Returns: + [*, N_res, N_res, C_z] output tensor + """ + if inplace_safe: + x = self._inference_forward( + z, + mask, + inplace_chunk_size=_inplace_chunk_size, + with_add=_add_with_inplace, + ) + return x + + if mask is None: + mask = z.new_ones(z.shape[:-1]) + + mask = mask.unsqueeze(-1) + + z = self.layer_norm_in(z) + a = mask + a = a * self.sigmoid(self.linear_a_g(z)) + a = a * self.linear_a_p(z) + b = mask + b = b * self.sigmoid(self.linear_b_g(z)) + b = b * self.linear_b_p(z) + + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + x = self._combine_projections(a.float(), b.float()) + else: + x = self._combine_projections(a, b) + + del a, b + x = self.layer_norm_out(x) + x = self.linear_z(x) + g = self.sigmoid(self.linear_g(z)) + x = x * g + + return x + + +class EsmFoldPreTrainedModel(EsmPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + # Subclass `EsMPreTrainedModel` to deal with special init + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, EsmFoldLinear): + with torch.no_grad(): + if module.init_fn is not None: + module.init_fn(module.weight, module.bias) + elif module.init == "default": + trunc_normal_init_(module.weight, scale=1.0) + elif module.init == "relu": + trunc_normal_init_(module.weight, scale=2.0) + elif module.init == "glorot": + nn.init.xavier_uniform_(module.weight, gain=1) + elif module.init == "gating": + module.weight.fill_(0.0) + if module.bias: + module.bias.fill_(1.0) + elif module.init == "normal": + torch.nn.init.kaiming_normal_(module.weight, nonlinearity="linear") + elif module.init == "final": + module.weight.fill_(0.0) + elif isinstance(module, EsmFoldInvariantPointAttention): + ipa_point_weights_init_(module.head_weights) + elif isinstance(module, EsmFoldTriangularSelfAttentionBlock): + torch.nn.init.zeros_(module.tri_mul_in.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_in.linear_z.bias) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.weight) + torch.nn.init.zeros_(module.tri_mul_out.linear_z.bias) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_start.mha.linear_o.bias) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.weight) + torch.nn.init.zeros_(module.tri_att_end.mha.linear_o.bias) + + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.weight) + torch.nn.init.zeros_(module.sequence_to_pair.o_proj.bias) + torch.nn.init.zeros_(module.pair_to_sequence.linear.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.weight) + torch.nn.init.zeros_(module.seq_attention.o_proj.bias) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_seq.mlp[-2].bias) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].weight) + torch.nn.init.zeros_(module.mlp_pair.mlp[-2].bias) + else: + super()._init_weights(module) + + +class EsmFoldSelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, head_width, gated=False): + super().__init__() + assert embed_dim == num_heads * head_width + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_width = head_width + + self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) + self.gated = gated + if gated: + self.g_proj = nn.Linear(embed_dim, embed_dim) + torch.nn.init.zeros_(self.g_proj.weight) + torch.nn.init.ones_(self.g_proj.bias) + + self.rescale_factor = self.head_width**-0.5 + + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, x, mask=None, bias=None, indices=None): + """ + Basic self attention with optional mask and external pairwise bias. To handle sequences of different lengths, + use mask. + + Inputs: + x: batch of input sequences (.. x L x C) mask: batch of boolean masks where 1=valid, 0=padding position (.. + x L_k) bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads) + + Outputs: + sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) + """ + + t = self.proj(x).view(*x.shape[:2], self.num_heads, -1) + t = t.permute(0, 2, 1, 3) + q, k, v = t.chunk(3, dim=-1) + + q = self.rescale_factor * q + a = torch.einsum("...qc,...kc->...qk", q, k) + + # Add external attention bias. + if bias is not None: + a = a + bias.permute(0, 3, 1, 2) + + # Do not attend to padding tokens. + if mask is not None: + mask = mask[:, None, None] + a = a.masked_fill(mask == False, -np.inf) # noqa: E712 + + a = nn.functional.softmax(a, dim=-1) + + y = torch.einsum("...hqk,...hkc->...qhc", a, v) + y = y.reshape(*y.shape[:2], -1) + + if self.gated: + y = self.g_proj(x).sigmoid() * y + y = self.o_proj(y) + + return y, a.permute(0, 3, 1, 2) + + +class EsmFoldDropout(nn.Module): + """ + Implementation of dropout with the ability to share the dropout mask along a particular dimension. + """ + + def __init__(self, r: float, batch_dim: Union[int, List[int]]): + super().__init__() + + self.r = r + if isinstance(batch_dim, int): + batch_dim = [batch_dim] + self.batch_dim = batch_dim + self.dropout = nn.Dropout(self.r) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = list(x.shape) + if self.batch_dim is not None: + for bd in self.batch_dim: + shape[bd] = 1 + return x * self.dropout(x.new_ones(shape)) + + +class EsmFoldSequenceToPair(nn.Module): + def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): + super().__init__() + + self.layernorm = nn.LayerNorm(sequence_state_dim) + self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) + self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) + + torch.nn.init.zeros_(self.proj.bias) + torch.nn.init.zeros_(self.o_proj.bias) + + def forward(self, sequence_state): + """ + Inputs: + sequence_state: B x L x sequence_state_dim + + Output: + pairwise_state: B x L x L x pairwise_state_dim + + Intermediate state: + B x L x L x 2*inner_dim + """ + + assert len(sequence_state.shape) == 3 + + s = self.layernorm(sequence_state) + s = self.proj(s) + q, k = s.chunk(2, dim=-1) + + prod = q[:, None, :, :] * k[:, :, None, :] + diff = q[:, None, :, :] - k[:, :, None, :] + + x = torch.cat([prod, diff], dim=-1) + x = self.o_proj(x) + + return x + + +class EsmFoldPairToSequence(nn.Module): + def __init__(self, pairwise_state_dim, num_heads): + super().__init__() + + self.layernorm = nn.LayerNorm(pairwise_state_dim) + self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) + + def forward(self, pairwise_state): + """ + Inputs: + pairwise_state: B x L x L x pairwise_state_dim + + Output: + pairwise_bias: B x L x L x num_heads + """ + assert len(pairwise_state.shape) == 4 + z = self.layernorm(pairwise_state) + pairwise_bias = self.linear(z) + return pairwise_bias + + +class EsmFoldResidueMLP(nn.Module): + def __init__(self, embed_dim, inner_dim, dropout=0): + super().__init__() + + self.mlp = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, embed_dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return x + self.mlp(x) + + +class EsmFoldTriangularSelfAttentionBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + sequence_state_dim = config.sequence_state_dim + pairwise_state_dim = config.pairwise_state_dim + sequence_num_heads = sequence_state_dim // config.sequence_head_width + pairwise_num_heads = pairwise_state_dim // config.pairwise_head_width + + self.layernorm_1 = nn.LayerNorm(sequence_state_dim) + + self.sequence_to_pair = EsmFoldSequenceToPair(sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim) + self.pair_to_sequence = EsmFoldPairToSequence(pairwise_state_dim, sequence_num_heads) + + self.seq_attention = EsmFoldSelfAttention( + sequence_state_dim, sequence_num_heads, config.sequence_head_width, gated=True + ) + self.tri_mul_out = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=True) + self.tri_mul_in = EsmFoldTriangleMultiplicativeUpdate(config, _outgoing=False) + + self.tri_att_start = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=True + ) + self.tri_att_end = EsmFoldTriangleAttention( + pairwise_state_dim, config.pairwise_head_width, pairwise_num_heads, inf=1e9, starting=False + ) + + self.mlp_seq = EsmFoldResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=config.dropout) + self.mlp_pair = EsmFoldResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=config.dropout) + + self.drop = nn.Dropout(config.dropout) + self.row_drop = EsmFoldDropout(config.dropout * 2, 2) + self.col_drop = EsmFoldDropout(config.dropout * 2, 1) + + def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): + """ + Inputs: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim mask: B x L boolean + tensor of valid positions + + Output: + sequence_state: B x L x sequence_state_dim pairwise_state: B x L x L x pairwise_state_dim + """ + if len(sequence_state.shape) != 3: + raise ValueError(f"`sequence_state` should be a 3d-tensor, got {len(sequence_state.shape)} dims.") + if len(pairwise_state.shape) != 4: + raise ValueError(f"`pairwise_state` should be a 4d-tensor, got {len(pairwise_state.shape)} dims.") + if mask is not None and len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + + batch_dim, seq_dim, sequence_state_dim = sequence_state.shape + pairwise_state_dim = pairwise_state.shape[3] + + if sequence_state_dim != self.config.sequence_state_dim: + raise ValueError( + "`sequence_state` last dimension should be equal to `self.sequence_state_dim`. Got " + f"{sequence_state_dim} != {self.config.sequence_state_dim}." + ) + if pairwise_state_dim != self.config.pairwise_state_dim: + raise ValueError( + "`pairwise_state` last dimension should be equal to `self.pairwise_state_dim`. Got " + f"{pairwise_state_dim} != {self.config.pairwise_state_dim}." + ) + if batch_dim != pairwise_state.shape[0]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent batch size: {batch_dim} != " + f"{pairwise_state.shape[0]}." + ) + if seq_dim != pairwise_state.shape[1] or seq_dim != pairwise_state.shape[2]: + raise ValueError( + f"`sequence_state` and `pairwise_state` have inconsistent sequence length: {seq_dim} != " + f"{pairwise_state.shape[1]} or {pairwise_state.shape[2]}." + ) + + # Update sequence state + bias = self.pair_to_sequence(pairwise_state) + + # Self attention with bias + mlp. + y = self.layernorm_1(sequence_state) + y, _ = self.seq_attention(y, mask=mask, bias=bias) + sequence_state = sequence_state + self.drop(y) + sequence_state = self.mlp_seq(sequence_state) + + # Update pairwise state + pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) + + # Axial attention with triangular bias. + tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None + pairwise_state = pairwise_state + self.row_drop(self.tri_mul_out(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.col_drop(self.tri_mul_in(pairwise_state, mask=tri_mask)) + pairwise_state = pairwise_state + self.row_drop( + self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + pairwise_state = pairwise_state + self.col_drop( + self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) + ) + + # MLP over pairs. + pairwise_state = self.mlp_pair(pairwise_state) + + return sequence_state, pairwise_state + + +class EsmCategoricalMixture: + def __init__(self, param, bins=50, start=0, end=1): + # All tensors are of shape ..., bins. + self.logits = param + bins = torch.linspace(start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype) + self.v_bins = (bins[:-1] + bins[1:]) / 2 + + def log_prob(self, true): + # Shapes are: + # self.probs: ... x bins + # true : ... + true_index = (true.unsqueeze(-1) - self.v_bins[[None] * true.ndim]).abs().argmin(-1) + nll = self.logits.log_softmax(-1) + return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) + + def mean(self): + return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) + + +def categorical_lddt(logits, bins=50): + # Logits are ..., 37, bins. + return EsmCategoricalMixture(logits, bins=bins).mean() + + +def get_axial_mask(mask): + """ + Helper to convert B x L mask of valid positions to axial mask used in row column attentions. + + Input: + mask: B x L tensor of booleans + + Output: + mask: B x L x L tensor of booleans + """ + + if mask is None: + return None + + if len(mask.shape) != 2: + raise ValueError(f"`mask` should be a 2d-tensor, got {len(mask.shape)} dims.") + batch_dim, seq_dim = mask.shape + m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) + m = m.reshape(batch_dim * seq_dim, seq_dim) + return m + + +class EsmFoldRelativePosition(nn.Module): + def __init__(self, config): + super().__init__() + self.bins = config.position_bins + + # Note an additional offset is used so that the 0th position + # is reserved for masked pairs. + self.embedding = torch.nn.Embedding(2 * self.bins + 2, config.pairwise_state_dim) + + def forward(self, residue_index, mask=None): + """ + Input: + residue_index: B x L tensor of indices (dtype=torch.long) mask: B x L tensor of booleans + + Output: + pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings + """ + if residue_index.dtype != torch.long: + raise ValueError(f"`residue_index` has dtype {residue_index.dtype}, it should be `torch.long`.") + if mask is not None and residue_index.shape != mask.shape: + raise ValueError( + f"`residue_index` and `mask` have inconsistent shapes: {residue_index.shape} != {mask.shape}." + ) + + diff = residue_index[:, None, :] - residue_index[:, :, None] + diff = diff.clamp(-self.bins, self.bins) + diff = diff + self.bins + 1 # Add 1 to adjust for padding index. + + if mask is not None: + mask = mask[:, None, :] * mask[:, :, None] + diff[mask == False] = 0 # noqa: E712 + + output = self.embedding(diff) + return output + + +class EsmFoldAngleResnetBlock(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.resnet_dim, config.resnet_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, a: torch.Tensor) -> torch.Tensor: + s_initial = a + + a = self.relu(a) + a = self.linear_1(a) + a = self.relu(a) + a = self.linear_2(a) + + return a + s_initial + + +class EsmFoldAngleResnet(nn.Module): + """ + Implements Algorithm 20, lines 11-14 + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + self.linear_initial = EsmFoldLinear(config.sequence_dim, config.resnet_dim) + + self.layers = nn.ModuleList() + for _ in range(config.num_resnet_blocks): + layer = EsmFoldAngleResnetBlock(config) + self.layers.append(layer) + + self.linear_out = EsmFoldLinear(config.resnet_dim, config.num_angles * 2) + + self.relu = nn.ReLU() + + def forward(self, s: torch.Tensor, s_initial: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + s: + [*, C_hidden] single embedding + s_initial: + [*, C_hidden] single embedding as of the start of the StructureModule + Returns: + [*, no_angles, 2] predicted angles + """ + # NOTE: The ReLU's applied to the inputs are absent from the supplement + # pseudocode but present in the source. For maximal compatibility with + # the pretrained weights, I'm going with the source. + + # [*, C_hidden] + s_initial = self.relu(s_initial) + s_initial = self.linear_initial(s_initial) + s = self.relu(s) + s = self.linear_in(s) + s = s + s_initial + + for l in self.layers: + s = l(s) + + s = self.relu(s) + + # [*, no_angles * 2] + s = self.linear_out(s) + + # [*, no_angles, 2] + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s**2, dim=-1, keepdim=True), + min=self.config.epsilon, + ) + ) + s = s / norm_denom + + return unnormalized_s, s + + +class EsmFoldInvariantPointAttention(nn.Module): + """ + Implements Algorithm 22. + """ + + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_dim + c_z = config.pairwise_dim + self.hidden_dim = config.ipa_dim + self.num_heads = config.num_heads_ipa + self.num_qk_points = config.num_qk_points + self.num_v_points = config.num_v_points + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + hc = config.ipa_dim * config.num_heads_ipa + self.linear_q = EsmFoldLinear(c_s, hc) + self.linear_kv = EsmFoldLinear(c_s, 2 * hc) + + hpq = config.num_heads_ipa * config.num_qk_points * 3 + self.linear_q_points = EsmFoldLinear(c_s, hpq) + + hpkv = config.num_heads_ipa * (config.num_qk_points + config.num_v_points) * 3 + self.linear_kv_points = EsmFoldLinear(c_s, hpkv) + + self.linear_b = EsmFoldLinear(c_z, config.num_heads_ipa) + + self.head_weights = nn.Parameter(torch.zeros((config.num_heads_ipa))) + + concat_out_dim = config.num_heads_ipa * (c_z + config.ipa_dim + config.num_v_points * 4) + self.linear_out = EsmFoldLinear(concat_out_dim, c_s, init="final") + + self.softmax = nn.Softmax(dim=-1) + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Rigid, + mask: torch.Tensor, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + z = [z] + + ####################################### + # Generate scalar and point activations + ####################################### + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + kv = self.linear_kv(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, 2 * C_hidden] + kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) + + # [*, N_res, H, C_hidden] + k, v = torch.split(kv, self.hidden_dim, dim=-1) + + # [*, N_res, H * P_q * 3] + q_pts = self.linear_q_points(s) + + # This is kind of clunky, but it's how the original does it + # [*, N_res, H * P_q, 3] + q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1) + q_pts = torch.stack(q_pts, dim=-1) + q_pts = r[..., None].apply(q_pts) + + # [*, N_res, H, P_q, 3] + q_pts = q_pts.view(q_pts.shape[:-2] + (self.num_heads, self.num_qk_points, 3)) + + # [*, N_res, H * (P_q + P_v) * 3] + kv_pts = self.linear_kv_points(s) + + # [*, N_res, H * (P_q + P_v), 3] + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = r[..., None].apply(kv_pts) + + # [*, N_res, H, (P_q + P_v), 3] + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + # [*, N_res, H, P_q/P_v, 3] + k_pts, v_pts = torch.split(kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if _offload_inference: + assert sys.getrefcount(z[0]) == 2 + z[0] = z[0].cpu() + + # [*, H, N_res, N_res] + if is_fp16_enabled(): + with torch.cuda.amp.autocast(enabled=False): + a = torch.matmul( + permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + else: + a = torch.matmul( + permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden] + permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res] + ) + + a *= math.sqrt(1.0 / (3 * self.hidden_dim)) + a += math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)) + + # [*, N_res, N_res, H, P_q, 3] + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att**2 + + # [*, N_res, N_res, H, P_q] + pt_att = sum(torch.unbind(pt_att, dim=-1)) + head_weights = self.softplus(self.head_weights).view(*((1,) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt(1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att = pt_att * head_weights + + # [*, N_res, N_res, H] + pt_att = torch.sum(pt_att, dim=-1) * (-0.5) + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.config.inf * (square_mask - 1) + + # [*, H, N_res, N_res] + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + + a = a + pt_att + a = a + square_mask.unsqueeze(-3) + a = self.softmax(a) + + ################ + # Compute output + ################ + # [*, N_res, H, C_hidden] + o = torch.matmul(a, v.transpose(-2, -3).to(dtype=a.dtype)).transpose(-2, -3) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, H, 3, N_res, P_v] + o_pt = torch.sum( + (a[..., None, :, :, None] * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + # [*, N_res, H, P_v, 3] + o_pt = permute_final_dims(o_pt, (2, 0, 3, 1)) + o_pt = r[..., None, None].invert_apply(o_pt) + + # [*, N_res, H * P_v] + o_pt_norm = flatten_final_dims(torch.sqrt(torch.sum(o_pt**2, dim=-1) + self.config.epsilon), 2) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3) + + if _offload_inference: + z[0] = z[0].to(o_pt.device) + + # [*, N_res, H, C_z] + o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1).to(dtype=z[0].dtype) + ) + + return s + + +class EsmFoldBackboneUpdate(nn.Module): + """ + Implements part of Algorithm 23. + """ + + def __init__(self, config): + super().__init__() + + self.linear = EsmFoldLinear(config.sequence_dim, 6, init="final") + + def forward(self, s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + [*, N_res, C_s] single representation + Returns: + [*, N_res, 6] update vector + """ + # [*, 6] + update = self.linear(s) + + return update + + +class EsmFoldStructureModuleTransitionLayer(nn.Module): + def __init__(self, config): + super().__init__() + + self.linear_1 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_2 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="relu") + self.linear_3 = EsmFoldLinear(config.sequence_dim, config.sequence_dim, init="final") + + self.relu = nn.ReLU() + + def forward(self, s): + s_initial = s + s = self.linear_1(s) + s = self.relu(s) + s = self.linear_2(s) + s = self.relu(s) + s = self.linear_3(s) + + s = s + s_initial + + return s + + +class EsmFoldStructureModuleTransition(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + self.layers = nn.ModuleList() + for _ in range(config.num_transition_layers): + l = EsmFoldStructureModuleTransitionLayer(config) + self.layers.append(l) + + self.dropout = nn.Dropout(config.dropout_rate) + self.layer_norm = LayerNorm(config.sequence_dim) + + def forward(self, s): + for l in self.layers: + s = l(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class EsmFoldStructureModule(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + # Buffers to be lazily initialized later + # self.default_frames + # self.group_idx + # self.atom_mask + # self.lit_positions + + self.layer_norm_s = LayerNorm(config.sequence_dim) + self.layer_norm_z = LayerNorm(config.pairwise_dim) + + self.linear_in = EsmFoldLinear(config.sequence_dim, config.sequence_dim) + + self.ipa = EsmFoldInvariantPointAttention(config) + + self.ipa_dropout = nn.Dropout(config.dropout_rate) + self.layer_norm_ipa = LayerNorm(config.sequence_dim) + + self.transition = EsmFoldStructureModuleTransition(config) + self.bb_update = EsmFoldBackboneUpdate(config) + self.angle_resnet = EsmFoldAngleResnet(config) + + def forward( + self, + evoformer_output_dict, + aatype, + mask=None, + _offload_inference=False, + ): + """ + Args: + evoformer_output_dict: + Dictionary containing: + "single": + [*, N_res, C_s] single representation + "pair": + [*, N_res, N_res, C_z] pair representation + aatype: + [*, N_res] amino acid indices + mask: + Optional [*, N_res] sequence mask + Returns: + A dictionary of outputs + """ + s = evoformer_output_dict["single"] + + if mask is None: + # [*, N] + mask = s.new_ones(s.shape[:-1]) + + # [*, N, C_s] + s = self.layer_norm_s(s) + + # [*, N, N, C_z] + z = self.layer_norm_z(evoformer_output_dict["pair"]) + + z_reference_list = None + if _offload_inference: + assert sys.getrefcount(evoformer_output_dict["pair"]) == 2 + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() + z_reference_list = [z] + z = None + + # [*, N, C_s] + s_initial = s + s = self.linear_in(s) + + # [*, N] + rigids = Rigid.identity( + s.shape[:-1], + s.dtype, + s.device, + self.training, + fmt="quat", + ) + outputs = [] + for i in range(self.config.num_blocks): + # [*, N, C_s] + s = s + self.ipa( + s, + z, + rigids, + mask, + _offload_inference=_offload_inference, + _z_reference_list=z_reference_list, + ) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # [*, N] + rigids = rigids.compose_q_update_vec(self.bb_update(s)) + + # To hew as closely as possible to AlphaFold, we convert our + # quaternion-based transformations to rotation-matrix ones + # here + backb_to_global = Rigid( + Rotation(rot_mats=rigids.get_rots().get_rot_mats(), quats=None), + rigids.get_trans(), + ) + + backb_to_global = backb_to_global.scale_translation(self.config.trans_scale_factor) + + # [*, N, 7, 2] + unnormalized_angles, angles = self.angle_resnet(s, s_initial) + + all_frames_to_global = self.torsion_angles_to_frames(backb_to_global, angles, aatype) + + pred_xyz = self.frames_and_literature_positions_to_atom14_pos(all_frames_to_global, aatype) + + scaled_rigids = rigids.scale_translation(self.config.trans_scale_factor) + + preds = { + "frames": scaled_rigids.to_tensor_7(), + "sidechain_frames": all_frames_to_global.to_tensor_4x4(), + "unnormalized_angles": unnormalized_angles, + "angles": angles, + "positions": pred_xyz, + "states": s, + } + + outputs.append(preds) + + rigids = rigids.stop_rot_gradient() + + del z, z_reference_list + + if _offload_inference: + evoformer_output_dict["pair"] = evoformer_output_dict["pair"].to(s.device) + + outputs = dict_multimap(torch.stack, outputs) + outputs["single"] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if not hasattr(self, "default_frames"): + self.register_buffer( + "default_frames", + torch.tensor( + residue_constants.restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "group_idx"): + self.register_buffer( + "group_idx", + torch.tensor( + residue_constants.restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "atom_mask"): + self.register_buffer( + "atom_mask", + torch.tensor( + residue_constants.restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + if not hasattr(self, "lit_positions"): + self.register_buffer( + "lit_positions", + torch.tensor( + residue_constants.restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ), + persistent=False, + ) + + def torsion_angles_to_frames(self, r, alpha, f): + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(alpha.dtype, alpha.device) + # Separated purely to make testing less annoying + return torsion_angles_to_frames(r, alpha, f, self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, r, f): # [*, N, 8] # [*, N] + # Lazily initialize the residue constants on the correct device + self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + r, + f, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) + + +class EsmFoldingTrunk(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + c_s = config.sequence_state_dim + c_z = config.pairwise_state_dim + + self.pairwise_positional_embedding = EsmFoldRelativePosition(config) + + self.blocks = nn.ModuleList([EsmFoldTriangularSelfAttentionBlock(config) for _ in range(config.num_blocks)]) + + self.recycle_bins = 15 + self.recycle_s_norm = nn.LayerNorm(c_s) + self.recycle_z_norm = nn.LayerNorm(c_z) + self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) + self.recycle_disto.weight[0].detach().zero_() + + self.structure_module = EsmFoldStructureModule(config.structure_module) + self.trunk2sm_s = nn.Linear(c_s, config.structure_module.sequence_dim) + self.trunk2sm_z = nn.Linear(c_z, config.structure_module.pairwise_dim) + + self.chunk_size = config.chunk_size + + def set_chunk_size(self, chunk_size): + # This parameter means the axial attention will be computed + # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). + # It's equivalent to running a for loop over chunks of the dimension we're iterative over, + # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-length chunks. + self.chunk_size = chunk_size + + def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles): + """ + Inputs: + seq_feats: B x L x C tensor of sequence features pair_feats: B x L x L x C tensor of pair features residx: B + x L long tensor giving the position in the sequence mask: B x L boolean tensor indicating valid residues + + Output: + predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object + """ + + device = seq_feats.device + s_s_0 = seq_feats + s_z_0 = pair_feats + + if no_recycles is None: + no_recycles = self.config.max_recycles + else: + if no_recycles < 0: + raise ValueError("Number of recycles must not be negative.") + no_recycles += 1 # First 'recycle' is just the standard forward pass through the model. + + def trunk_iter(s, z, residx, mask): + z = z + self.pairwise_positional_embedding(residx, mask=mask) + + for block in self.blocks: + s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size) + return s, z + + s_s = s_s_0 + s_z = s_z_0 + recycle_s = torch.zeros_like(s_s) + recycle_z = torch.zeros_like(s_z) + recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64) + + for recycle_idx in range(no_recycles): + with ContextManagers([] if recycle_idx == no_recycles - 1 else [torch.no_grad()]): + # === Recycling === + recycle_s = self.recycle_s_norm(recycle_s.detach()).to(device) + recycle_z = self.recycle_z_norm(recycle_z.detach()).to(device) + recycle_z += self.recycle_disto(recycle_bins.detach()).to(device) + + s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask) + + # === Structure module === + structure = self.structure_module( + {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, + true_aa, + mask.float(), + ) + + recycle_s = s_s + recycle_z = s_z + # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold. + recycle_bins = EsmFoldingTrunk.distogram( + structure["positions"][-1][:, :, :3], + 3.375, + 21.375, + self.recycle_bins, + ) + + structure["s_s"] = s_s + structure["s_z"] = s_z + + return structure + + @staticmethod + def distogram(coords, min_bin, max_bin, num_bins): + # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. + boundaries = torch.linspace( + min_bin, + max_bin, + num_bins - 1, + device=coords.device, + ) + boundaries = boundaries**2 + N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] + # Infer CB coordinates. + b = CA - N + c = C - CA + a = b.cross(c, dim=-1) + CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA + dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) + bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] + return bins + + +# TODO Add information to the docstring about any methods that convert to PDB format, or otherwise prepare +# the outputs for downstream use. + + +@add_start_docstrings( + """ + ESMForProteinFolding is the HuggingFace port of the original ESMFold model. It consists of an ESM-2 "stem" followed + by a protein folding "head", although unlike most other output heads, this "head" is similar in size and runtime to + the rest of the model combined! It outputs a dictionary containing predicted structural information about the input + protein(s). + """, + ESM_START_DOCSTRING, +) +class EsmForProteinFolding(EsmPreTrainedModel): + _no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"] + + def __init__(self, config): + super().__init__(config) + + self.config = config + + self.distogram_bins = 64 + + self.esm = EsmModel(config, add_pooling_layer=False) + + self.esm.requires_grad_(False) + if self.config.esmfold_config.fp16_esm: + self.esm.half() + + self.esm_feats = self.config.hidden_size + self.esm_attns = self.config.num_hidden_layers * self.config.num_attention_heads + self.esm_layers = self.config.num_hidden_layers + self.register_buffer("af2_to_esm", self._af2_to_esm_from_vocab_list(config.vocab_list)) + self.esm_s_combine = nn.Parameter(torch.zeros(self.esm_layers + 1)) + + trunk_config = self.config.esmfold_config.trunk + c_s = trunk_config.sequence_state_dim + c_z = trunk_config.pairwise_state_dim + self.esm_s_mlp = nn.Sequential( + LayerNorm(self.esm_feats), + nn.Linear(self.esm_feats, c_s), + nn.ReLU(), + nn.Linear(c_s, c_s), + ) + + # 0 is padding, N is unknown residues, N + 1 is mask. + self.n_tokens_embed = residue_constants.restype_num + 3 + self.pad_idx = 0 + self.unk_idx = self.n_tokens_embed - 2 + self.mask_idx = self.n_tokens_embed - 1 + self.esm_dict_cls_idx = self.config.vocab_list.index("") + self.esm_dict_mask_idx = self.config.vocab_list.index("") + self.esm_dict_eos_idx = self.config.vocab_list.index("") + self.esm_dict_padding_idx = self.config.vocab_list.index("") + if self.config.esmfold_config.embed_aa: + self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) + + self.trunk = EsmFoldingTrunk(trunk_config) + + self.distogram_head = nn.Linear(c_z, self.distogram_bins) + self.ptm_head = nn.Linear(c_z, self.distogram_bins) + self.lm_head = nn.Linear(c_s, self.n_tokens_embed) + self.lddt_bins = 50 + structure_module_config = trunk_config.structure_module + self.lddt_head = nn.Sequential( + nn.LayerNorm(structure_module_config.sequence_dim), + nn.Linear(structure_module_config.sequence_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, self.config.esmfold_config.lddt_head_hid_dim), + nn.Linear(self.config.esmfold_config.lddt_head_hid_dim, 37 * self.lddt_bins), + ) + + @staticmethod + def _af2_to_esm_from_vocab_list(vocab_list: List[str]) -> torch.Tensor: + # Remember that t is shifted from residue_constants by 1 (0 is padding). + esm_reorder = [vocab_list.index("")] + [vocab_list.index(v) for v in residue_constants.restypes_with_x] + return torch.tensor(esm_reorder) + + @add_start_docstrings_to_model_forward(ESMFOLD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=EsmForProteinFoldingOutput, config_class=EsmConfig) + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + masking_pattern: Optional[torch.Tensor] = None, + num_recycles: Optional[int] = None, + ) -> EsmForProteinFoldingOutput: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, EsmForProteinFolding + + >>> model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1") + >>> inputs = tokenizer(["MLKNVQVQLV"], return_tensors="pt", add_special_tokens=False) # A tiny random peptide + >>> outputs = model(**inputs) + >>> folded_positions = outputs.positions + ``` + + """ + cfg = self.config.esmfold_config + + aa = input_ids # B x L + B = aa.shape[0] + L = aa.shape[1] + device = input_ids.device + if attention_mask is None: + attention_mask = torch.ones_like(aa, device=device) + if position_ids is None: + position_ids = torch.arange(L, device=device).expand_as(input_ids) + + # === ESM === + esmaa = self.af2_idx_to_esm_idx(aa, attention_mask) + + if masking_pattern is not None: + masked_aa, esmaa, mlm_targets = self.bert_mask(aa, esmaa, attention_mask, masking_pattern) + else: + masked_aa = aa + mlm_targets = None + + # We get sequence and pair representations from whatever version of ESM / + # configuration we are using. The sequence representation esm_s is always + # present. The pair embedding esm_z may be present depending on the + # configuration of the model. If esm_z is not used by the model then it + # is returned as None here. + esm_s = self.compute_language_model_representations(esmaa) + + # Convert esm_s and esm_z, if present, to the precision used by the trunk and + # the structure module. These tensors may be a lower precision if, for example, + # we're running the language model in fp16 precision. + esm_s = esm_s.to(self.esm_s_combine.dtype) + + if cfg.esm_ablate_sequence: + esm_s = esm_s * 0 + + esm_s = esm_s.detach() + + # === preprocessing === + esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) + s_s_0 = self.esm_s_mlp(esm_s) + + s_z_0 = s_s_0.new_zeros(B, L, L, cfg.trunk.pairwise_state_dim) + + if self.config.esmfold_config.embed_aa: + s_s_0 += self.embedding(masked_aa) + + structure: dict = self.trunk(s_s_0, s_z_0, aa, position_ids, attention_mask, no_recycles=num_recycles) + # Documenting what we expect: + structure = { + k: v + for k, v in structure.items() + if k + in [ + "s_z", + "s_s", + "frames", + "sidechain_frames", + "unnormalized_angles", + "angles", + "positions", + "states", + ] + } + + # Add BERT mask for the loss to use, if available. + if mlm_targets: + structure["mlm_targets"] = mlm_targets + + disto_logits = self.distogram_head(structure["s_z"]) + disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 + structure["distogram_logits"] = disto_logits + + lm_logits = self.lm_head(structure["s_s"]) + structure["lm_logits"] = lm_logits + + structure["aatype"] = aa + make_atom14_masks(structure) + # Of course, this doesn't respect the true mask because it doesn't know about it... + # We're not going to properly mask change of index tensors: + # "residx_atom14_to_atom37", + # "residx_atom37_to_atom14", + for k in [ + "atom14_atom_exists", + "atom37_atom_exists", + ]: + structure[k] *= attention_mask.unsqueeze(-1) + structure["residue_index"] = position_ids + + lddt_head = self.lddt_head(structure["states"]).reshape(structure["states"].shape[0], B, L, -1, self.lddt_bins) + structure["lddt_head"] = lddt_head + plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins) + structure["plddt"] = plddt + + ptm_logits = self.ptm_head(structure["s_z"]) + structure["ptm_logits"] = ptm_logits + structure["ptm"] = compute_tm(ptm_logits, max_bin=31, no_bins=self.distogram_bins) + structure.update(compute_predicted_aligned_error(ptm_logits, max_bin=31, no_bins=self.distogram_bins)) + + return EsmForProteinFoldingOutput(**structure) + + def af2_idx_to_esm_idx(self, aa, mask): + # avoid indexing on different devices + if self.af2_to_esm.device != aa.device: + self.af2_to_esm = self.af2_to_esm.to(aa.device) + aa = (aa + 1).masked_fill(mask != 1, 0) + return self.af2_to_esm[aa] + + def compute_language_model_representations(self, esmaa: torch.Tensor) -> torch.Tensor: + device = next(self.parameters()).device + B, L = esmaa.shape # B = batch size, L = sequence length. + + if self.config.esmfold_config.bypass_lm: + esm_s = torch.zeros(B, L, self.esm_s_combine.size[0], -1, self.esm_feats, device=device) + return esm_s + + bosi, eosi = self.esm_dict_cls_idx, self.esm_dict_eos_idx + bos = esmaa.new_full((B, 1), bosi) + eos = esmaa.new_full((B, 1), self.esm_dict_padding_idx) + esmaa = torch.cat([bos, esmaa, eos], dim=1) + # Use the first padding index as eos during inference. + esmaa[range(B), (esmaa != 1).sum(1)] = eosi + + # _, esm_z, esm_s = self.esm(esmaa, return_pairs=self.config.esmfold_config.use_esm_attn_map) + # Because we do not support use_esm_attn_map in the HF port as it is not used in any public models, + # esm_z is always None + esm_hidden_states = self.esm(esmaa, attention_mask=esmaa != 1, output_hidden_states=True)["hidden_states"] + esm_s = torch.stack(esm_hidden_states, dim=2) + + esm_s = esm_s[:, 1:-1] # B, L, nLayers, C + + return esm_s + + def bert_mask(self, aa, esmaa, mask, pattern): + new_aa = aa.clone() + target = aa.clone() + new_esmaa = esmaa.clone() + new_aa[pattern == 1] = self.mask_idx + target[pattern != 1] = 0 + new_esmaa[pattern == 1] = self.esm_dict_mask_idx + return new_aa, new_esmaa, target + + @torch.no_grad() + def infer( + self, + seqs: Union[str, List[str]], + position_ids=None, + ): + if isinstance(seqs, str): + lst = [seqs] + else: + lst = seqs + # Returns the raw outputs of the model given an input sequence. + device = next(self.parameters()).device + aatype = collate_dense_tensors( + [ + torch.from_numpy( + residue_constants.sequence_to_onehot( + sequence=seq, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + ) + .to(device) + .argmax(dim=1) + for seq in lst + ] + ) # B=1 x L + mask = collate_dense_tensors([aatype.new_ones(len(seq)) for seq in lst]) + position_ids = ( + torch.arange(aatype.shape[1], device=device).expand(len(lst), -1) + if position_ids is None + else position_ids.to(device) + ) + if position_ids.ndim == 1: + position_ids = position_ids.unsqueeze(0) + return self.forward( + aatype, + mask, + position_ids=position_ids, + ) + + @staticmethod + def output_to_pdb(output: Dict) -> List[str]: + """Returns the pbd (file) string from the model given the model output.""" + output = {k: v.to("cpu").numpy() for k, v in output.items()} + pdbs = [] + final_atom_positions = atom14_to_atom37(output["positions"][-1], output) + final_atom_mask = output["atom37_atom_exists"] + for i in range(output["aatype"].shape[0]): + aa = output["aatype"][i] + pred_pos = final_atom_positions[i] + mask = final_atom_mask[i] + resid = output["residue_index"][i] + 1 + pred = OFProtein( + aatype=aa, + atom_positions=pred_pos, + atom_mask=mask, + residue_index=resid, + b_factors=output["plddt"][i], + ) + pdbs.append(to_pdb(pred)) + return pdbs + + def infer_pdb(self, seqs, *args, **kwargs) -> str: + """Returns the pdb (file) string from the model given an input sequence.""" + assert isinstance(seqs, str) + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output)[0] + + def infer_pdbs(self, seqs: List[str], *args, **kwargs) -> List[str]: + """Returns the pdb (file) string from the model given an input sequence.""" + output = self.infer(seqs, *args, **kwargs) + return self.output_to_pdb(output) + + +__all__ = ["EsmForProteinFolding", "EsmFoldPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py b/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..71698486dab0adfece1feb14ce91b4500e24382c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/modeling_tf_esm.py @@ -0,0 +1,1575 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ESM model.""" + +from __future__ import annotations + +import os +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFMaskedLMOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras, + shape_list, + unpack_inputs, +) +from ...tf_utils import check_embeddings_within_bounds, stable_softmax +from ...utils import logging +from .configuration_esm import EsmConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/esm2_t6_8M_UR50D" +_CONFIG_FOR_DOC = "EsmConfig" + + +def rotate_half(x): + x1, x2 = tf.split(x, 2, axis=-1) + return tf.concat((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, :, : tf.shape(x)[-2], :] + sin = sin[:, :, : tf.shape(x)[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + tf.linalg.matrix_transpose(x) # Transposes last two dimensions only + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = tf.reduce_sum(x, -1, keepdims=True) + a2 = tf.reduce_sum(x, -2, keepdims=True) + a12 = tf.reduce_sum(x, (-1, -2), keepdims=True) + + avg = a1 * a2 + avg = avg / a12 + normalized = x - avg + return normalized + + +class TFRotaryEmbedding(keras.layers.Layer): + """ + Rotary position embeddings based on those in + [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation + matrices which depend on their relative positions. + """ + + def __init__(self, dim: int, name=None): + super().__init__(name=name) + # Matt: The PyTorch version of this layer does a lot of work to cache values, but we just rely on TF compilation + # and/or XLA to sort out constants like that. It actually may not seem like this layer needs to be stateful at + # all when we benefit from TF compilation, but it does. The reason is that self.inv_freq is a buffer in the + # original implementation, but all the shared ESM checkpoints were trained with fp16 params. This means that + # the inv_freq tensor was stored as a float16, and we need to replicate those lower-precision values or our + # models give different outputs from the original. + self.dim = dim + + def build(self, input_shape): + super().build(input_shape) + self.inv_freq = self.add_weight( + "inv_freq", shape=(self.dim // 2,), dtype=tf.float32, initializer=get_initializer(1.0), trainable=False + ) + self.inv_freq.assign( + 1.0 / (10000 ** (tf.range(start=0, limit=self.dim, delta=2, dtype=tf.float32) / self.dim)) + ) + + def _compute_cos_sin(self, x, seq_dimension=2): + seq_len = tf.shape(x)[seq_dimension] + + t = tf.range(seq_len, dtype=self.inv_freq.dtype) + freqs = tf.einsum("i, j -> ij", t, self.inv_freq) # Outer multiplication + emb = tf.concat((freqs, freqs), axis=-1)[None, None, :, :] + + return tf.cos(emb), tf.sin(emb) + + def call(self, q: tf.Tensor, k: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + cos_emb, sin_emb = self._compute_cos_sin(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, cos_emb, sin_emb), + apply_rotary_pos_emb(k, cos_emb, sin_emb), + ) + + +class TFEsmContactPredictionHead(keras.layers.Layer): + """Performs symmetrization, apc, and computes a logistic regression on the output features""" + + def __init__( + self, + in_features: int, + bias=True, + eos_idx: int = 2, + name=None, + ): + super().__init__(name=name) + self.eos_idx = eos_idx + self.in_features = in_features + self.regression = keras.layers.Dense(1, use_bias=bias, activation="sigmoid", name="regression") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "regression", None) is not None: + with tf.name_scope(self.regression.name): + self.regression.build((None, self.in_features)) + + def call(self, tokens, attentions): + # remove eos token attentions + eos_mask = tf.cast(tokens != self.eos_idx, attentions.dtype) + eos_mask = tf.expand_dims(eos_mask, 1) * tf.expand_dims(eos_mask, 2) + attentions = attentions * eos_mask[:, None, None, :, :] + attentions = attentions[..., :-1, :-1] + # remove cls token attentions + attentions = attentions[..., 1:, 1:] + batch_size, layers, heads, seqlen, _ = shape_list(attentions) + attentions = tf.reshape(attentions, (batch_size, layers * heads, seqlen, seqlen)) + + # features: batch x channels x tokens x tokens (symmetric) + attentions = average_product_correct(symmetrize(attentions)) + attentions = tf.transpose(attentions, perm=(0, 2, 3, 1)) + return tf.squeeze(self.regression(attentions), 3) + + +class TFEsmEmbeddings(keras.layers.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, name=None): + super().__init__(name=name) + self.word_embeddings = keras.layers.Embedding( + config.vocab_size, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="word_embeddings", + ) + self.position_embeddings = keras.layers.Embedding( + config.max_position_embeddings, + config.hidden_size, + embeddings_initializer=get_initializer(config.initializer_range), + name="position_embeddings", + ) + + if config.emb_layer_norm_before: + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + else: + self.layer_norm = None + # Matt: I think this line was copied incorrectly from BERT, disabling for now + # self.dropout = Dropout(config.hidden_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.position_ids = tf.range(config.max_position_embeddings)[None, :] + + self.padding_idx = config.pad_token_id + self.token_dropout = config.token_dropout + self.mask_token_id = config.mask_token_id + self.config = config + + def call( + self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if inputs_embeds is None: + check_embeddings_within_bounds(input_ids, self.config.vocab_size) + inputs_embeds = self.word_embeddings(input_ids) + + # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an + # embedding_scale factor here. + embeddings = inputs_embeds + + # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout + # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however, + # masked tokens are treated as if they were selected for input dropout and zeroed out. + # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by + # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample). + # This is analogous to the way that dropout layers scale down outputs during evaluation when not + # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training). + if self.token_dropout: + embeddings = tf.where((input_ids == self.mask_token_id)[:, :, None], 0.0, embeddings) + mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all ESM model training runs + src_lengths = tf.cast(tf.reduce_sum(attention_mask, axis=-1), tf.float32) + masked_tokens = input_ids == self.mask_token_id + mask_ratio_observed = tf.math.count_nonzero(masked_tokens, dtype=tf.float32, axis=-1) / src_lengths + embeddings = embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if self.layer_norm is not None: + embeddings = self.layer_norm(embeddings) + if attention_mask is not None: + embeddings = embeddings * tf.cast(tf.expand_dims(attention_mask, -1), embeddings.dtype) + # Matt: I think this line was copied incorrectly from BERT, disabling it for now. + # embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: tf.Tensor + + Returns: tf.Tensor + """ + input_shape = shape_list(inputs_embeds)[:-1] + sequence_length = input_shape[1] + + position_ids = tf.range( + start=self.padding_idx + 1, limit=sequence_length + self.padding_idx + 1, dtype=tf.int64 + ) + return tf.broadcast_to(tf.expand_dims(position_ids, 0), input_shape) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "word_embeddings", None) is not None: + with tf.name_scope(self.word_embeddings.name): + self.word_embeddings.build(None) + if getattr(self, "position_embeddings", None) is not None: + with tf.name_scope(self.position_embeddings.name): + self.position_embeddings.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + + +class TFEsmSelfAttention(keras.layers.Layer): + def __init__(self, config, position_embedding_type=None, name=None): + super().__init__(name=name) + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + + self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + self.rotary_embeddings = None + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = keras.layers.Embedding( + 2 * config.max_position_embeddings - 1, + self.attention_head_size, + embeddings_initializer=get_initializer(config.initializer_range), + ) + elif self.position_embedding_type == "rotary": + self.rotary_embeddings = TFRotaryEmbedding(dim=self.attention_head_size, name="rotary_embeddings") + + self.is_decoder = config.is_decoder + self.config = config + + def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] + x = tf.reshape(x, new_x_shape) + return tf.transpose(x, perm=(0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + encoder_hidden_states: tf.Tensor | None = None, + encoder_attention_mask: tf.Tensor | None = None, + past_key_value: Tuple[Tuple[tf.Tensor]] | None = None, + output_attentions: Optional[bool] = False, + training: bool = False, + ) -> Tuple[tf.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim). + # ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent, + # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original + # ESM code and fix rotary embeddings. + query_layer = query_layer * self.attention_head_size**-0.5 + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + if self.position_embedding_type == "rotary": + query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = shape_list(hidden_states)[1] + position_ids_l = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), -1) + position_ids_r = tf.expand_dims(tf.range(seq_length, dtype=tf.int64), 0) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = tf.cast(positional_embedding, query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = tf.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = tf.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in EsmModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = attention_probs @ value_layer + + context_layer = tf.transpose(context_layer, perm=(0, 2, 1, 3)) + new_context_layer_shape = shape_list(context_layer)[:-2] + [self.all_head_size] + context_layer = tf.reshape(context_layer, new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + if getattr(self, "rotary_embeddings", None) is not None: + with tf.name_scope(self.rotary_embeddings.name): + self.rotary_embeddings.build(None) + + +class TFEsmSelfOutput(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmAttention(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.self = TFEsmSelfAttention(config, name="self") + self.output_layer = TFEsmSelfOutput(config, name="output") + self.pruned_heads = set() + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + hidden_states_ln = self.LayerNorm(hidden_states) + self_outputs = self.self( + hidden_states_ln, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training, + ) + attention_output = self.output_layer(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self", None) is not None: + with tf.name_scope(self.self.name): + self.self.build(None) + if getattr(self, "output_layer", None) is not None: + with tf.name_scope(self.output_layer.name): + self.output_layer.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFEsmIntermediate(keras.layers.Layer): + def __init__(self, config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, + kernel_initializer=get_initializer(config.initializer_range), + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = tf.nn.gelu(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmOutput(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states, input_tensor, training=False): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states += input_tensor + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +class TFEsmLayer(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = TFEsmAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFEsmAttention(config) + self.intermediate = TFEsmIntermediate(config, name="intermediate") + self.output_layer = TFEsmOutput(config, name="output") + self.LayerNorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.config = config + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + training=False, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise AttributeError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated" + " with cross-attention layers by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layernorm_output = self.LayerNorm(attention_output) + intermediate_output = self.intermediate(hidden_states=layernorm_output) + layer_output = self.output_layer( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "output_layer", None) is not None: + with tf.name_scope(self.output_layer.name): + self.output_layer.build(None) + if getattr(self, "LayerNorm", None) is not None: + with tf.name_scope(self.LayerNorm.name): + self.LayerNorm.build([None, None, self.config.hidden_size]) + + +class TFEsmEncoder(keras.layers.Layer): + def __init__(self, config, name=None): + super().__init__(name=name) + self.config = config + self.layer = [TFEsmLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + self.emb_layer_norm_after = keras.layers.LayerNormalization( + epsilon=config.layer_norm_eps, name="emb_layer_norm_after" + ) + + def call( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + training=False, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + training, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if self.emb_layer_norm_after: + hidden_states = self.emb_layer_norm_after(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "emb_layer_norm_after", None) is not None: + with tf.name_scope(self.emb_layer_norm_after.name): + self.emb_layer_norm_after.build([None, None, self.config.hidden_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->Esm +class TFEsmPooler(keras.layers.Layer): + def __init__(self, config: EsmConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFEsmPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = EsmConfig + base_model_prefix = "esm" + + +ESM_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Keras [Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it as a + regular Keras model and refer to the TF/Keras documentation for all matters related to general usage and behavior. + + Parameters: + config ([`EsmConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ESM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmMainLayer(keras.layers.Layer): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + """ + + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config, add_pooling_layer=True, name=None, **kwargs): + super().__init__(name=name, **kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.embeddings = TFEsmEmbeddings(config, name="embeddings") + self.encoder = TFEsmEncoder(config, name="encoder") + self.pooler = TFEsmPooler(config, name="pooler") if add_pooling_layer else None + + self.contact_head = TFEsmContactPredictionHead( + in_features=self.config.num_hidden_layers * self.config.num_attention_heads, bias=True, name="contact_head" + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "pooler", None) is not None: + with tf.name_scope(self.pooler.name): + self.pooler.build(None) + if getattr(self, "contact_head", None) is not None: + with tf.name_scope(self.contact_head.name): + self.contact_head.build(None) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.word_embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + attns = self(tokens, attention_mask=attention_mask, return_dict=True, output_attentions=True).attentions + attns = tf.stack(attns, axis=1) # Matches the original model layout + # In the original model, attentions for padding tokens are completely zeroed out. + # This makes no difference most of the time because the other tokens won't attend to them, + # but it does for the contact prediction task, which takes attentions as input, + # so we have to mimic that here. + attention_mask = tf.cast(attention_mask, attns.dtype) + attns *= attention_mask[:, None, None, None] + attns *= attention_mask[:, None, None, :, None] + return self.contact_head(tokens, attns) + + +@add_start_docstrings( + "The bare ESM Model transformer outputting raw hidden-states without any specific head on top.", + ESM_START_DOCSTRING, +) +class TFEsmModel(TFEsmPreTrainedModel): + def __init__(self, config: EsmConfig, add_pooling_layer=True, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=add_pooling_layer, name="esm") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.esm( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + return outputs + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + + +@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING) +class TFEsmForMaskedLM(TFEsmPreTrainedModel, TFMaskedLanguageModelingLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.lm_head = TFEsmLMHead(config, name="lm_head") + if config.tie_word_embeddings: + # Ensure word embeddings are built so that we actually have something to tie + with tf.name_scope(os.path.join(self._name_scope(), "esm", "embeddings", "word_embeddings")): + self.esm.embeddings.word_embeddings.build((None, None)) + self.lm_head.decoder = self.esm.embeddings.word_embeddings.weights[0] + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + def get_lm_head(self): + return self.lm_head + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + encoder_hidden_states: np.ndarray | tf.Tensor | None = None, + encoder_attention_mask: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, *optional*, defaults to `{}`): + Used to hide legacy arguments that have been deprecated. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + masked_lm_loss = self.hf_compute_loss(labels=labels, logits=prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return TFMaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def predict_contacts(self, tokens, attention_mask): + return self.esm.predict_contacts(tokens, attention_mask) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "lm_head", None) is not None: + with tf.name_scope(self.lm_head.name): + self.lm_head.build(None) + + +class TFEsmLMHead(keras.layers.Layer): + """ESM Head for masked language modeling.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + if config.tie_word_embeddings: + self.decoder = None + else: + self.decoder = keras.layers.Dense( + config.vocab_size, + kernel_initializer=get_initializer(config.initializer_range), + name="decoder", + use_bias=False, + ) + self.config = config + + def build(self, input_shape=None): + # Separate bias to match the PT model and allow weight cross-loading to work + # Put it in the build so it gets the right name when adding it as a weight + if self.built: + return + self.built = True + self.bias = self.add_weight("bias", shape=(self.config.vocab_size,), initializer="zeros", trainable=True) + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "decoder", None) is not None and not self.config.tie_word_embeddings: + with tf.name_scope(self.decoder.name): + self.decoder.build([None, None, self.config.hidden_size]) + + def get_bias(self): + return {"bias": self.bias} + + def call(self, features): + x = self.dense(features) + x = tf.nn.gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + if self.config.tie_word_embeddings: + x = tf.matmul(x, self.decoder, transpose_b=True) + self.bias + else: + x = self.decoder(x) + self.bias + return x + + +@add_start_docstrings( + """ + ESM Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled + output) e.g. for GLUE tasks. + """, + ESM_START_DOCSTRING, +) +class TFEsmForSequenceClassification(TFEsmPreTrainedModel, TFSequenceClassificationLoss): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.classifier = TFEsmClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build(None) + + +@add_start_docstrings( + """ + ESM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ESM_START_DOCSTRING, +) +class TFEsmForTokenClassification(TFEsmPreTrainedModel, TFTokenClassificationLoss): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.esm = TFEsmMainLayer(config, add_pooling_layer=False, name="esm") + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = keras.layers.Dense(config.num_labels, name="classifier") + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(ESM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: TFModelInputType | None = None, + attention_mask: np.ndarray | tf.Tensor | None = None, + position_ids: np.ndarray | tf.Tensor | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + inputs_embeds: np.ndarray | tf.Tensor | None = None, + labels: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.esm( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "esm", None) is not None: + with tf.name_scope(self.esm.name): + self.esm.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +class TFEsmClassificationHead(keras.layers.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config, name=None): + super().__init__(name=name) + self.dense = keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout_prob) + self.out_proj = keras.layers.Dense( + config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + activation="linear", + name="out_proj", + ) + self.config = config + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.config.hidden_size]) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: tf.Tensor x: + + Returns: tf.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = tf.cast(input_ids != padding_idx, tf.int64) + incremental_indices = (tf.cumsum(mask, axis=1) + past_key_values_length) * mask + return incremental_indices + padding_idx + + +__all__ = [ + "TFEsmForMaskedLM", + "TFEsmForSequenceClassification", + "TFEsmForTokenClassification", + "TFEsmModel", + "TFEsmPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02a8c149ae320dd9b045edc5df31760a4eebefd9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/__init__.py @@ -0,0 +1,8 @@ +from .chunk_utils import chunk_layer +from .data_transforms import make_atom14_masks +from .feats import atom14_to_atom37, frames_and_literature_positions_to_atom14_pos, torsion_angles_to_frames +from .loss import compute_predicted_aligned_error, compute_tm +from .protein import Protein as OFProtein +from .protein import to_pdb +from .rigid_utils import Rigid, Rotation +from .tensor_utils import dict_multimap, flatten_final_dims, permute_final_dims diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51ff6b74d6c3f59ac23b68e2be0999b5f8dbb4ef --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/chunk_utils.py @@ -0,0 +1,397 @@ +# Copyright 2021 AlQuraishi Laboratory +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + +from .tensor_utils import tensor_tree_map, tree_map + + +def _fetch_dims(tree: Union[dict, list, tuple, torch.Tensor]) -> List[Tuple[int, ...]]: + shapes = [] + if isinstance(tree, dict): + for v in tree.values(): + shapes.extend(_fetch_dims(v)) + elif isinstance(tree, (list, tuple)): + for t in tree: + shapes.extend(_fetch_dims(t)) + elif isinstance(tree, torch.Tensor): + shapes.append(tree.shape) + else: + raise TypeError("Not supported") + + return shapes + + +@torch.jit.ignore +def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int, ...]) -> Tuple[int, ...]: + idx = [] + for d in reversed(dims): + idx.append(flat_idx % d) + flat_idx = flat_idx // d + + return tuple(reversed(idx)) + + +@torch.jit.ignore +def _get_minimal_slice_set( + start: Sequence[int], + end: Sequence[int], + dims: Sequence[int], + start_edges: Optional[Sequence[bool]] = None, + end_edges: Optional[Sequence[bool]] = None, +) -> List[Tuple[slice, ...]]: + """ + Produces an ordered sequence of tensor slices that, when used in sequence on a tensor with shape dims, yields + tensors that contain every leaf in the contiguous range [start, end]. Care is taken to yield a short sequence of + slices, and perhaps even the shortest possible (I'm pretty sure it's the latter). + + end is INCLUSIVE. + """ + + # start_edges and end_edges both indicate whether, starting from any given + # dimension, the start/end index is at the top/bottom edge of the + # corresponding tensor, modeled as a tree + def reduce_edge_list(l: List[bool]) -> None: + tally = True + for i in range(len(l)): + reversed_idx = -1 * (i + 1) + l[reversed_idx] &= tally + tally = l[reversed_idx] + + if start_edges is None: + start_edges = [s == 0 for s in start] + reduce_edge_list(start_edges) + if end_edges is None: + end_edges = [e == (d - 1) for e, d in zip(end, dims)] + reduce_edge_list(end_edges) + + # Base cases. Either start/end are empty and we're done, or the final, + # one-dimensional tensor can be simply sliced + if len(start) == 0: + return [()] + elif len(start) == 1: + return [(slice(start[0], end[0] + 1),)] + + slices: List[Tuple[slice, ...]] = [] + path_list: List[slice] = [] + + # Dimensions common to start and end can be selected directly + for s, e in zip(start, end): + if s == e: + path_list.append(slice(s, s + 1)) + else: + break + + path: Tuple[slice, ...] = tuple(path_list) + divergence_idx = len(path) + + # start == end, and we're done + if divergence_idx == len(dims): + return [path] + + def upper() -> Tuple[Tuple[slice, ...], ...]: + assert start_edges is not None + assert end_edges is not None + + sdi = start[divergence_idx] + return tuple( + path + (slice(sdi, sdi + 1),) + s + for s in _get_minimal_slice_set( + start[divergence_idx + 1 :], + [d - 1 for d in dims[divergence_idx + 1 :]], + dims[divergence_idx + 1 :], + start_edges=start_edges[divergence_idx + 1 :], + end_edges=[True for _ in end_edges[divergence_idx + 1 :]], + ) + ) + + def lower() -> Tuple[Tuple[slice, ...], ...]: + assert start_edges is not None + assert end_edges is not None + + edi = end[divergence_idx] + return tuple( + path + (slice(edi, edi + 1),) + s + for s in _get_minimal_slice_set( + [0 for _ in start[divergence_idx + 1 :]], + end[divergence_idx + 1 :], + dims[divergence_idx + 1 :], + start_edges=[True for _ in start_edges[divergence_idx + 1 :]], + end_edges=end_edges[divergence_idx + 1 :], + ) + ) + + # If both start and end are at the edges of the subtree rooted at + # divergence_idx, we can just select the whole subtree at once + if start_edges[divergence_idx] and end_edges[divergence_idx]: + slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),)) + # If just start is at the edge, we can grab almost all of the subtree, + # treating only the ragged bottom edge as an edge case + elif start_edges[divergence_idx]: + slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),)) + slices.extend(lower()) + # Analogous to the previous case, but the top is ragged this time + elif end_edges[divergence_idx]: + slices.extend(upper()) + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)) + # If both sides of the range are ragged, we need to handle both sides + # separately. If there's contiguous meat in between them, we can index it + # in one big chunk + else: + slices.extend(upper()) + middle_ground = end[divergence_idx] - start[divergence_idx] + if middle_ground > 1: + slices.append(path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)) + slices.extend(lower()) + + return slices + + +@torch.jit.ignore +def _chunk_slice(t: torch.Tensor, flat_start: int, flat_end: int, no_batch_dims: int) -> torch.Tensor: + """ + Equivalent to + + t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end] + + but without the need for the initial reshape call, which can be memory-intensive in certain situations. The only + reshape operations in this function are performed on sub-tensors that scale with (flat_end - flat_start), the chunk + size. + """ + + batch_dims = t.shape[:no_batch_dims] + start_idx = list(_flat_idx_to_idx(flat_start, batch_dims)) + # _get_minimal_slice_set is inclusive + end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims)) + + # Get an ordered list of slices to perform + slices = _get_minimal_slice_set( + start_idx, + end_idx, + batch_dims, + ) + + sliced_tensors = [t[s] for s in slices] + + return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]) + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + no_batch_dims: int, + low_mem: bool = False, + _out: Any = None, + _add_into_out: bool = False, +) -> Any: + """ + Implements the "chunking" procedure described in section 1.11.8. + + Layer outputs and inputs are assumed to be simple "pytrees," consisting only of (arbitrarily nested) lists, tuples, + and dicts with torch.Tensor leaves. + + Args: + layer: + The layer to be applied chunk-wise + inputs: + A (non-nested) dictionary of keyworded inputs. All leaves must be tensors and must share the same batch + dimensions. + chunk_size: + The number of sub-batches per chunk. If multiple batch dimensions are specified, a "sub-batch" is defined + as a single indexing of all batch dimensions simultaneously (s.t. the number of sub-batches is the product + of the batch dimensions). + no_batch_dims: + How many of the initial dimensions of each input tensor can be considered batch dimensions. + low_mem: + Avoids flattening potentially large input tensors. Unnecessary in most cases, and is ever so slightly + slower than the default setting. + Returns: + The reassembled output of the layer on the inputs. + """ + if not (len(inputs) > 0): + raise ValueError("Must provide at least one input") + + initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + def _prep_inputs(t: torch.Tensor) -> torch.Tensor: + if not low_mem: + if not sum(t.shape[:no_batch_dims]) == no_batch_dims: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + t = t.reshape(-1, *t.shape[no_batch_dims:]) + else: + t = t.expand(orig_batch_dims + t.shape[no_batch_dims:]) + return t + + prepped_inputs: Dict[str, Any] = tensor_tree_map(_prep_inputs, inputs) + prepped_outputs = None + if _out is not None: + prepped_outputs = tensor_tree_map(lambda t: t.view([-1] + list(t.shape[no_batch_dims:])), _out) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + + no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0) + + def _select_chunk(t: torch.Tensor) -> torch.Tensor: + return t[i : i + chunk_size] if t.shape[0] != 1 else t + + i = 0 + out = prepped_outputs + for _ in range(no_chunks): + # Chunk the input + if not low_mem: + select_chunk = _select_chunk + else: + select_chunk = partial( + _chunk_slice, + flat_start=i, + flat_end=min(flat_batch_dim, i + chunk_size), + no_batch_dims=len(orig_batch_dims), + ) + + chunks: Dict[str, Any] = tensor_tree_map(select_chunk, prepped_inputs) + + # Run the layer on the chunk + output_chunk = layer(**chunks) + + # Allocate space for the output + if out is None: + out = tensor_tree_map(lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:]), output_chunk) + + # Put the chunk in its pre-allocated space + if isinstance(output_chunk, dict): + + def assign(d1: dict, d2: dict) -> None: + for k, v in d1.items(): + if isinstance(v, dict): + assign(v, d2[k]) + else: + if _add_into_out: + v[i : i + chunk_size] += d2[k] + else: + v[i : i + chunk_size] = d2[k] + + assign(out, output_chunk) + elif isinstance(output_chunk, tuple): + for x1, x2 in zip(out, output_chunk): + if _add_into_out: + x1[i : i + chunk_size] += x2 + else: + x1[i : i + chunk_size] = x2 + elif isinstance(output_chunk, torch.Tensor): + if _add_into_out: + out[i : i + chunk_size] += output_chunk + else: + out[i : i + chunk_size] = output_chunk + else: + raise TypeError("Not supported") + + i += chunk_size + + out = tensor_tree_map(lambda t: t.view(orig_batch_dims + t.shape[1:]), out) + + return out + + +class ChunkSizeTuner: + def __init__( + self, + # Heuristically, runtimes for most of the modules in the network + # plateau earlier than this on all GPUs I've run the model on. + max_chunk_size: int = 512, + ): + self.max_chunk_size = max_chunk_size + self.cached_chunk_size: Optional[int] = None + self.cached_arg_data: Optional[tuple] = None + + def _determine_favorable_chunk_size(self, fn: Callable, args: tuple, min_chunk_size: int) -> int: + logging.info("Tuning chunk size...") + + if min_chunk_size >= self.max_chunk_size: + return min_chunk_size + + candidates: List[int] = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)] + candidates = [c for c in candidates if c > min_chunk_size] + candidates = [min_chunk_size] + candidates + candidates[-1] += 4 + + def test_chunk_size(chunk_size: int) -> bool: + try: + with torch.no_grad(): + fn(*args, chunk_size=chunk_size) + return True + except RuntimeError: + return False + + min_viable_chunk_size_index = 0 + i = len(candidates) - 1 + while i > min_viable_chunk_size_index: + viable = test_chunk_size(candidates[i]) + if not viable: + i = (min_viable_chunk_size_index + i) // 2 + else: + min_viable_chunk_size_index = i + i = (i + len(candidates) - 1) // 2 + + return candidates[min_viable_chunk_size_index] + + def _compare_arg_caches(self, ac1: Iterable, ac2: Iterable) -> bool: + consistent = True + for a1, a2 in zip(ac1, ac2): + assert type(ac1) is type(ac2) + if isinstance(ac1, (list, tuple)): + consistent &= self._compare_arg_caches(a1, a2) + elif isinstance(ac1, dict): + a1_items = [v for _, v in sorted(a1.items(), key=lambda x: x[0])] + a2_items = [v for _, v in sorted(a2.items(), key=lambda x: x[0])] + consistent &= self._compare_arg_caches(a1_items, a2_items) + else: + consistent &= a1 == a2 + + return consistent + + def tune_chunk_size( + self, + representative_fn: Callable, + args: tuple, + min_chunk_size: int, + ) -> int: + consistent = True + arg_data: tuple = tree_map(lambda a: a.shape if isinstance(a, torch.Tensor) else a, args, object) + if self.cached_arg_data is not None: + # If args have changed shape/value, we need to re-tune + assert len(self.cached_arg_data) == len(arg_data) + consistent = self._compare_arg_caches(self.cached_arg_data, arg_data) + else: + # Otherwise, we can reuse the precomputed value + consistent = False + + if not consistent: + self.cached_chunk_size = self._determine_favorable_chunk_size( + representative_fn, + args, + min_chunk_size, + ) + self.cached_arg_data = arg_data + + assert self.cached_chunk_size is not None + + return self.cached_chunk_size diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4c17589ae66df2a8fd0ccfe8d6e335004eed9a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/data_transforms.py @@ -0,0 +1,93 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import numpy as np +import torch + +from . import residue_constants as rc +from .tensor_utils import tensor_tree_map, tree_map + + +def make_atom14_masks(protein: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Construct denser atom positions (14 dimensions instead of 37).""" + restype_atom14_to_atom37_list = [] + restype_atom37_to_atom14_list = [] + restype_atom14_mask_list = [] + + for rt in rc.restypes: + atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]] + restype_atom14_to_atom37_list.append([(rc.atom_order[name] if name else 0) for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14_list.append( + [(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in rc.atom_types] + ) + + restype_atom14_mask_list.append([(1.0 if name else 0.0) for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37_list.append([0] * 14) + restype_atom37_to_atom14_list.append([0] * 37) + restype_atom14_mask_list.append([0.0] * 14) + + restype_atom14_to_atom37 = torch.tensor( + restype_atom14_to_atom37_list, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom37_to_atom14 = torch.tensor( + restype_atom37_to_atom14_list, + dtype=torch.int32, + device=protein["aatype"].device, + ) + restype_atom14_mask = torch.tensor( + restype_atom14_mask_list, + dtype=torch.float32, + device=protein["aatype"].device, + ) + protein_aatype = protein["aatype"].to(torch.long) + + # create the mapping for (residx, atom14) --> atom37, i.e. an array + # with shape (num_res, 14) containing the atom37 indices for this protein + residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype] + residx_atom14_mask = restype_atom14_mask[protein_aatype] + + protein["atom14_atom_exists"] = residx_atom14_mask + protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long() + + # create the gather indices for mapping back + residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype] + protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long() + + # create the corresponding mask + restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=protein["aatype"].device) + for restype, restype_letter in enumerate(rc.restypes): + restype_name = rc.restype_1to3[restype_letter] + atom_names = rc.residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = rc.atom_order[atom_name] + restype_atom37_mask[restype, atom_type] = 1 + + residx_atom37_mask = restype_atom37_mask[protein_aatype] + protein["atom37_atom_exists"] = residx_atom37_mask + + return protein + + +def make_atom14_masks_np(batch: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]: + batch = tree_map(lambda n: torch.tensor(n, device=batch["aatype"].device), batch, np.ndarray) + out = tensor_tree_map(lambda t: np.array(t), make_atom14_masks(batch)) + return out diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7b90dfe79b24b852cb26fca998bda831f36a6f --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/feats.py @@ -0,0 +1,253 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple, overload + +import torch +import torch.types +from torch import nn + +from . import residue_constants as rc +from .rigid_utils import Rigid, Rotation +from .tensor_utils import batched_gather + + +@overload +def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor: ... + + +@overload +def pseudo_beta_fn( + aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: ... + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order["G"] + ca_idx = rc.atom_order["CA"] + cb_idx = rc.atom_order["CB"] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + atom37_data = batched_gather( + atom14, + batch["residx_atom37_to_atom14"], + dim=-2, + no_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch["atom37_atom_exists"][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor: + template_aatype = template_feats["template_aatype"] + torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"] + alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"] + torsion_angles_mask = template_feats["template_torsion_angles_mask"] + template_angle_feat = torch.cat( + [ + nn.functional.one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + + return template_angle_feat + + +def build_template_pair_feat( + batch: Dict[str, torch.Tensor], + min_bin: torch.types.Number, + max_bin: torch.types.Number, + no_bins: int, + use_unit_vector: bool = False, + eps: float = 1e-20, + inf: float = 1e8, +) -> torch.Tensor: + template_mask = batch["template_pseudo_beta_mask"] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + # Compute distogram (this seems to differ slightly from Alg. 5) + tpb = batch["template_pseudo_beta"] + dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot: torch.LongTensor = nn.functional.one_hot( + batch["template_aatype"], + rc.restype_num + 2, + ) + + n_res = batch["template_aatype"].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]] + rigids = Rigid.make_transform_from_reference( + n_xyz=batch["template_all_atom_positions"][..., n, :], + ca_xyz=batch["template_all_atom_positions"][..., ca, :], + c_xyz=batch["template_all_atom_positions"][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch["template_all_atom_mask"] + template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + inv_distance_scalar = inv_distance_scalar * template_mask_2d + unit_vector = rigid_vec * inv_distance_scalar[..., None] + + if not use_unit_vector: + unit_vector = unit_vector * 0.0 + + to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor: + msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23) + msa_feat = [ + msa_1hot, + batch["extra_has_deletion"].unsqueeze(-1), + batch["extra_deletion_value"].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) + + +def torsion_angles_to_frames( + r: Rigid, + alpha: torch.Tensor, + aatype: torch.Tensor, + rrgdf: torch.Tensor, +) -> Rigid: + # [*, N, 8, 4, 4] + default_4x4 = rrgdf[aatype, ...] + + # [*, N, 8] transformations, i.e. + # One [*, N, 8, 3, 3] rotation matrix and + # One [*, N, 8, 3] translation matrix + default_r = r.from_tensor_4x4(default_4x4) + + bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + # [*, N, 8, 2] + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2) + + # [*, N, 8, 3, 3] + # Produces rotation matrices of the form: + # [ + # [1, 0 , 0 ], + # [0, a_2,-a_1], + # [0, a_1, a_2] + # ] + # This follows the original code rather than the supplement, which uses + # different indices. + + all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None)) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Rigid.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = r[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + r: Rigid, + aatype: torch.Tensor, + default_frames: torch.Tensor, + group_idx: torch.Tensor, + atom_mask: torch.Tensor, + lit_positions: torch.Tensor, +) -> torch.Tensor: + # [*, N, 14] + group_mask = group_idx[aatype, ...] + + # [*, N, 14, 8] + group_mask_one_hot: torch.LongTensor = nn.functional.one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + # [*, N, 14, 8] + t_atoms_to_global = r[..., None, :] * group_mask_one_hot + + # [*, N, 14] + t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1)) + + # [*, N, 14, 1] + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + # [*, N, 14, 3] + lit_positions = lit_positions[aatype, ...] + pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..8c442786dc82ba2ebe243923509ed76a40de2a01 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/loss.py @@ -0,0 +1,105 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch + + +def _calculate_bin_centers(boundaries: torch.Tensor) -> torch.Tensor: + step = boundaries[1] - boundaries[0] + bin_centers = boundaries + step / 2 + bin_centers = torch.cat([bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0) + return bin_centers + + +def _calculate_expected_aligned_error( + alignment_confidence_breaks: torch.Tensor, + aligned_distance_error_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bin_centers = _calculate_bin_centers(alignment_confidence_breaks) + return ( + torch.sum(aligned_distance_error_probs * bin_centers, dim=-1), + bin_centers[-1], + ) + + +def compute_predicted_aligned_error( + logits: torch.Tensor, + max_bin: int = 31, + no_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + + Args: + logits: [*, num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + max_bin: Maximum bin value + no_bins: Number of bins + Returns: + aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [*, num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: [*] the maximum predicted error possible. + """ + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1) + predicted_aligned_error, max_predicted_aligned_error = _calculate_expected_aligned_error( + alignment_confidence_breaks=boundaries, + aligned_distance_error_probs=aligned_confidence_probs, + ) + + return { + "aligned_confidence_probs": aligned_confidence_probs, + "predicted_aligned_error": predicted_aligned_error, + "max_predicted_aligned_error": max_predicted_aligned_error, + } + + +def compute_tm( + logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + no_bins: int = 64, + eps: float = 1e-8, + **kwargs, +) -> torch.Tensor: + if residue_weights is None: + residue_weights = logits.new_ones(logits.shape[-2]) + + boundaries = torch.linspace(0, max_bin, steps=(no_bins - 1), device=logits.device) + + bin_centers = _calculate_bin_centers(boundaries) + torch.sum(residue_weights) + n = logits.shape[-2] + clipped_n = max(n, 19) + + d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 + + probs = torch.nn.functional.softmax(logits, dim=-1) + + tm_per_bin = 1.0 / (1 + (bin_centers**2) / (d0**2)) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + normed_residue_mask = residue_weights / (eps + residue_weights.sum()) + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + + weighted = per_alignment * residue_weights + + argmax = (weighted == torch.max(weighted)).nonzero()[0] + return per_alignment[tuple(argmax)] diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9d8c13277bd8c5e7dd152f50d549b6f7286af3 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/protein.py @@ -0,0 +1,330 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Protein data type.""" + +import dataclasses +import re +import string +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple + +import numpy as np + +from . import residue_constants + + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. +PICO_TO_ANGSTROM = 0.01 + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + # Chain indices for multi-chain predictions + chain_index: Optional[np.ndarray] = None + + # Optional remark about the protein. Included as a comment in output PDB + # files + remark: Optional[str] = None + + # Templates used to generate this protein (prediction-only) + parents: Optional[Sequence[str]] = None + + # Chain corresponding to each parent + parents_chain_index: Optional[Sequence[int]] = None + + +def from_proteinnet_string(proteinnet_str: str) -> Protein: + tag_re = r"(\[[A-Z]+\]\n)" + tags: List[str] = [tag.strip() for tag in re.split(tag_re, proteinnet_str) if len(tag) > 0] + groups: Iterator[Tuple[str, List[str]]] = zip(tags[0::2], [l.split("\n") for l in tags[1::2]]) + + atoms: List[str] = ["N", "CA", "C"] + aatype = None + atom_positions = None + atom_mask = None + for g in groups: + if "[PRIMARY]" == g[0]: + seq = g[1][0].strip() + for i in range(len(seq)): + if seq[i] not in residue_constants.restypes: + seq[i] = "X" # FIXME: strings are immutable + aatype = np.array( + [residue_constants.restype_order.get(res_symbol, residue_constants.restype_num) for res_symbol in seq] + ) + elif "[TERTIARY]" == g[0]: + tertiary: List[List[float]] = [] + for axis in range(3): + tertiary.append(list(map(float, g[1][axis].split()))) + tertiary_np = np.array(tertiary) + atom_positions = np.zeros((len(tertiary[0]) // 3, residue_constants.atom_type_num, 3)).astype(np.float32) + for i, atom in enumerate(atoms): + atom_positions[:, residue_constants.atom_order[atom], :] = np.transpose(tertiary_np[:, i::3]) + atom_positions *= PICO_TO_ANGSTROM + elif "[MASK]" == g[0]: + mask = np.array(list(map({"-": 0, "+": 1}.get, g[1][0].strip()))) + atom_mask = np.zeros( + ( + len(mask), + residue_constants.atom_type_num, + ) + ).astype(np.float32) + for i, atom in enumerate(atoms): + atom_mask[:, residue_constants.atom_order[atom]] = 1 + atom_mask *= mask[..., None] + + assert aatype is not None + + return Protein( + atom_positions=atom_positions, + atom_mask=atom_mask, + aatype=aatype, + residue_index=np.arange(len(aatype)), + b_factors=None, + ) + + +def get_pdb_headers(prot: Protein, chain_id: int = 0) -> List[str]: + pdb_headers: List[str] = [] + + remark = prot.remark + if remark is not None: + pdb_headers.append(f"REMARK {remark}") + + parents = prot.parents + parents_chain_index = prot.parents_chain_index + if parents is not None and parents_chain_index is not None: + parents = [p for i, p in zip(parents_chain_index, parents) if i == chain_id] + + if parents is None or len(parents) == 0: + parents = ["N/A"] + + pdb_headers.append(f"PARENT {' '.join(parents)}") + + return pdb_headers + + +def add_pdb_headers(prot: Protein, pdb_str: str) -> str: + """Add pdb headers to an existing PDB string. Useful during multi-chain + recycling + """ + out_pdb_lines: List[str] = [] + lines = pdb_str.split("\n") + + remark = prot.remark + if remark is not None: + out_pdb_lines.append(f"REMARK {remark}") + + parents_per_chain: List[List[str]] + if prot.parents is not None and len(prot.parents) > 0: + parents_per_chain = [] + if prot.parents_chain_index is not None: + parent_dict: Dict[str, List[str]] = {} + for p, i in zip(prot.parents, prot.parents_chain_index): + parent_dict.setdefault(str(i), []) + parent_dict[str(i)].append(p) + + max_idx = max([int(chain_idx) for chain_idx in parent_dict]) + for i in range(max_idx + 1): + chain_parents = parent_dict.get(str(i), ["N/A"]) + parents_per_chain.append(chain_parents) + else: + parents_per_chain.append(list(prot.parents)) + else: + parents_per_chain = [["N/A"]] + + def make_parent_line(p: Sequence[str]) -> str: + return f"PARENT {' '.join(p)}" + + out_pdb_lines.append(make_parent_line(parents_per_chain[0])) + + chain_counter = 0 + for i, l in enumerate(lines): + if "PARENT" not in l and "REMARK" not in l: + out_pdb_lines.append(l) + if "TER" in l and "END" not in lines[i + 1]: + chain_counter += 1 + if not chain_counter >= len(parents_per_chain): + chain_parents = parents_per_chain[chain_counter] + else: + chain_parents = ["N/A"] + + out_pdb_lines.append(make_parent_line(chain_parents)) + + return "\n".join(out_pdb_lines) + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ["X"] + + def res_1to3(r: int) -> str: + return residue_constants.restype_1to3.get(restypes[r], "UNK") + + atom_types = residue_constants.atom_types + + pdb_lines: List[str] = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + b_factors = prot.b_factors + chain_index = prot.chain_index + + if np.any(aatype > residue_constants.restype_num): + raise ValueError("Invalid aatypes.") + + headers = get_pdb_headers(prot) + if len(headers) > 0: + pdb_lines.extend(headers) + + n = aatype.shape[0] + atom_index = 1 + prev_chain_index = 0 + chain_tags = string.ascii_uppercase + chain_tag = None + # Add all atom sites. + for i in range(n): + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip(atom_types, atom_positions[i], atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = "ATOM" + name = atom_name if len(atom_name) == 4 else f" {atom_name}" + alt_loc = "" + insertion_code = "" + occupancy = 1.00 + element = atom_name[0] # Protein supports only C, N, O, S, this works. + charge = "" + + chain_tag = "A" + if chain_index is not None: + chain_tag = chain_tags[chain_index[i]] + + # PDB is a columnar format, every space matters here! + atom_line = ( + f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}" + f"{res_name_3:>3} {chain_tag:>1}" + f"{residue_index[i]:>4}{insertion_code:>1} " + f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}" + f"{occupancy:>6.2f}{b_factor:>6.2f} " + f"{element:>2}{charge:>2}" + ) + pdb_lines.append(atom_line) + atom_index += 1 + + should_terminate = i == n - 1 + if chain_index is not None: + if i != n - 1 and chain_index[i + 1] != prev_chain_index: + should_terminate = True + prev_chain_index = chain_index[i + 1] + + if should_terminate: + # Close the chain. + chain_end = "TER" + chain_termination_line = ( + f"{chain_end:<6}{atom_index:>5} {res_1to3(aatype[i]):>3} {chain_tag:>1}{residue_index[i]:>4}" + ) + pdb_lines.append(chain_termination_line) + atom_index += 1 + + if i != n - 1: + # "prev" is a misnomer here. This happens at the beginning of + # each new chain. + pdb_lines.extend(get_pdb_headers(prot, prev_chain_index)) + + pdb_lines.append("END") + pdb_lines.append("") + return "\n".join(pdb_lines) + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are reported in the PDB. This function + computes a mask according to heavy atoms that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction( + features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None, + chain_index: Optional[np.ndarray] = None, + remark: Optional[str] = None, + parents: Optional[Sequence[str]] = None, + parents_chain_index: Optional[Sequence[int]] = None, +) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + result: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + chain_index: (Optional) Chain indices for multi-chain predictions + remark: (Optional) Remark about the prediction + parents: (Optional) List of template names + Returns: + A protein instance. + """ + return Protein( + aatype=features["aatype"], + atom_positions=result["final_atom_positions"], + atom_mask=result["final_atom_mask"], + residue_index=features["residue_index"] + 1, + b_factors=b_factors if b_factors is not None else np.zeros_like(result["final_atom_mask"]), + chain_index=chain_index, + remark=remark, + parents=parents, + parents_chain_index=parents_chain_index, + ) diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b05a603fb29f51e2b870c73fa28899cd91936bdb --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/residue_constants.py @@ -0,0 +1,981 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants used in AlphaFold.""" + +import collections +import copy +import functools +from importlib import resources +from typing import Dict, List, Mapping, Sequence, Tuple + +import numpy as np + + +# Internal import (35fd). + + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms: Dict[str, List[List[str]]] = { + "ALA": [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + "ARG": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "NE"], ["CG", "CD", "NE", "CZ"]], + "ASN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "ASP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "OD1"]], + "CYS": [["N", "CA", "CB", "SG"]], + "GLN": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]], + "GLU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "OE1"]], + "GLY": [], + "HIS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "ND1"]], + "ILE": [["N", "CA", "CB", "CG1"], ["CA", "CB", "CG1", "CD1"]], + "LEU": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "LYS": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"], ["CB", "CG", "CD", "CE"], ["CG", "CD", "CE", "NZ"]], + "MET": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "SD"], ["CB", "CG", "SD", "CE"]], + "PHE": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "PRO": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD"]], + "SER": [["N", "CA", "CB", "OG"]], + "THR": [["N", "CA", "CB", "OG1"]], + "TRP": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "TYR": [["N", "CA", "CB", "CG"], ["CA", "CB", "CG", "CD1"]], + "VAL": [["N", "CA", "CB", "CG1"]], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask: List[List[float]] = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. +chi_pi_periodic: List[List[float]] = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions: Dict[str, List[Tuple[str, int, Tuple[float, float, float]]]] = { + "ALA": [ + ("N", 0, (-0.525, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.529, -0.774, -1.205)), + ("O", 3, (0.627, 1.062, 0.000)), + ], + "ARG": [ + ("N", 0, (-0.524, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.524, -0.778, -1.209)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG", 4, (0.616, 1.390, -0.000)), + ("CD", 5, (0.564, 1.414, 0.000)), + ("NE", 6, (0.539, 1.357, -0.000)), + ("NH1", 7, (0.206, 2.301, 0.000)), + ("NH2", 7, (2.078, 0.978, -0.000)), + ("CZ", 7, (0.758, 1.093, -0.000)), + ], + "ASN": [ + ("N", 0, (-0.536, 1.357, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.531, -0.787, -1.200)), + ("O", 3, (0.625, 1.062, 0.000)), + ("CG", 4, (0.584, 1.399, 0.000)), + ("ND2", 5, (0.593, -1.188, 0.001)), + ("OD1", 5, (0.633, 1.059, 0.000)), + ], + "ASP": [ + ("N", 0, (-0.525, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, 0.000, -0.000)), + ("CB", 0, (-0.526, -0.778, -1.208)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.593, 1.398, -0.000)), + ("OD1", 5, (0.610, 1.091, 0.000)), + ("OD2", 5, (0.592, -1.101, -0.003)), + ], + "CYS": [ + ("N", 0, (-0.522, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, 0.000, 0.000)), + ("CB", 0, (-0.519, -0.773, -1.212)), + ("O", 3, (0.625, 1.062, -0.000)), + ("SG", 4, (0.728, 1.653, 0.000)), + ], + "GLN": [ + ("N", 0, (-0.526, 1.361, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, 0.000)), + ("CB", 0, (-0.525, -0.779, -1.207)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.615, 1.393, 0.000)), + ("CD", 5, (0.587, 1.399, -0.000)), + ("NE2", 6, (0.593, -1.189, -0.001)), + ("OE1", 6, (0.634, 1.060, 0.000)), + ], + "GLU": [ + ("N", 0, (-0.528, 1.361, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, -0.000, -0.000)), + ("CB", 0, (-0.526, -0.781, -1.207)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG", 4, (0.615, 1.392, 0.000)), + ("CD", 5, (0.600, 1.397, 0.000)), + ("OE1", 6, (0.607, 1.095, -0.000)), + ("OE2", 6, (0.589, -1.104, -0.001)), + ], + "GLY": [ + ("N", 0, (-0.572, 1.337, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.517, -0.000, -0.000)), + ("O", 3, (0.626, 1.062, -0.000)), + ], + "HIS": [ + ("N", 0, (-0.527, 1.360, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, 0.000, 0.000)), + ("CB", 0, (-0.525, -0.778, -1.208)), + ("O", 3, (0.625, 1.063, 0.000)), + ("CG", 4, (0.600, 1.370, -0.000)), + ("CD2", 5, (0.889, -1.021, 0.003)), + ("ND1", 5, (0.744, 1.160, -0.000)), + ("CE1", 5, (2.030, 0.851, 0.002)), + ("NE2", 5, (2.145, -0.466, 0.004)), + ], + "ILE": [ + ("N", 0, (-0.493, 1.373, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, -0.000)), + ("CB", 0, (-0.536, -0.793, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG1", 4, (0.534, 1.437, -0.000)), + ("CG2", 4, (0.540, -0.785, -1.199)), + ("CD1", 5, (0.619, 1.391, 0.000)), + ], + "LEU": [ + ("N", 0, (-0.520, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.522, -0.773, -1.214)), + ("O", 3, (0.625, 1.063, -0.000)), + ("CG", 4, (0.678, 1.371, 0.000)), + ("CD1", 5, (0.530, 1.430, -0.000)), + ("CD2", 5, (0.535, -0.774, 1.200)), + ], + "LYS": [ + ("N", 0, (-0.526, 1.362, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, 0.000)), + ("CB", 0, (-0.524, -0.778, -1.208)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.619, 1.390, 0.000)), + ("CD", 5, (0.559, 1.417, 0.000)), + ("CE", 6, (0.560, 1.416, 0.000)), + ("NZ", 7, (0.554, 1.387, 0.000)), + ], + "MET": [ + ("N", 0, (-0.521, 1.364, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, 0.000, 0.000)), + ("CB", 0, (-0.523, -0.776, -1.210)), + ("O", 3, (0.625, 1.062, -0.000)), + ("CG", 4, (0.613, 1.391, -0.000)), + ("SD", 5, (0.703, 1.695, 0.000)), + ("CE", 6, (0.320, 1.786, -0.000)), + ], + "PHE": [ + ("N", 0, (-0.518, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, 0.000, -0.000)), + ("CB", 0, (-0.525, -0.776, -1.212)), + ("O", 3, (0.626, 1.062, -0.000)), + ("CG", 4, (0.607, 1.377, 0.000)), + ("CD1", 5, (0.709, 1.195, -0.000)), + ("CD2", 5, (0.706, -1.196, 0.000)), + ("CE1", 5, (2.102, 1.198, -0.000)), + ("CE2", 5, (2.098, -1.201, -0.000)), + ("CZ", 5, (2.794, -0.003, -0.001)), + ], + "PRO": [ + ("N", 0, (-0.566, 1.351, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, 0.000)), + ("CB", 0, (-0.546, -0.611, -1.293)), + ("O", 3, (0.621, 1.066, 0.000)), + ("CG", 4, (0.382, 1.445, 0.0)), + # ('CD', 5, (0.427, 1.440, 0.0)), + ("CD", 5, (0.477, 1.424, 0.0)), # manually made angle 2 degrees larger + ], + "SER": [ + ("N", 0, (-0.529, 1.360, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, -0.000)), + ("CB", 0, (-0.518, -0.777, -1.211)), + ("O", 3, (0.626, 1.062, -0.000)), + ("OG", 4, (0.503, 1.325, 0.000)), + ], + "THR": [ + ("N", 0, (-0.517, 1.364, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.526, 0.000, -0.000)), + ("CB", 0, (-0.516, -0.793, -1.215)), + ("O", 3, (0.626, 1.062, 0.000)), + ("CG2", 4, (0.550, -0.718, -1.228)), + ("OG1", 4, (0.472, 1.353, 0.000)), + ], + "TRP": [ + ("N", 0, (-0.521, 1.363, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.525, -0.000, 0.000)), + ("CB", 0, (-0.523, -0.776, -1.212)), + ("O", 3, (0.627, 1.062, 0.000)), + ("CG", 4, (0.609, 1.370, -0.000)), + ("CD1", 5, (0.824, 1.091, 0.000)), + ("CD2", 5, (0.854, -1.148, -0.005)), + ("CE2", 5, (2.186, -0.678, -0.007)), + ("CE3", 5, (0.622, -2.530, -0.007)), + ("NE1", 5, (2.140, 0.690, -0.004)), + ("CH2", 5, (3.028, -2.890, -0.013)), + ("CZ2", 5, (3.283, -1.543, -0.011)), + ("CZ3", 5, (1.715, -3.389, -0.011)), + ], + "TYR": [ + ("N", 0, (-0.522, 1.362, 0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.524, -0.000, -0.000)), + ("CB", 0, (-0.522, -0.776, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG", 4, (0.607, 1.382, -0.000)), + ("CD1", 5, (0.716, 1.195, -0.000)), + ("CD2", 5, (0.713, -1.194, -0.001)), + ("CE1", 5, (2.107, 1.200, -0.002)), + ("CE2", 5, (2.104, -1.201, -0.003)), + ("OH", 5, (4.168, -0.002, -0.005)), + ("CZ", 5, (2.791, -0.001, -0.003)), + ], + "VAL": [ + ("N", 0, (-0.494, 1.373, -0.000)), + ("CA", 0, (0.000, 0.000, 0.000)), + ("C", 0, (1.527, -0.000, -0.000)), + ("CB", 0, (-0.533, -0.795, -1.213)), + ("O", 3, (0.627, 1.062, -0.000)), + ("CG1", 4, (0.540, 1.429, -0.000)), + ("CG2", 4, (0.533, -0.776, 1.203)), + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms: Dict[str, List[str]] = { + "ALA": ["C", "CA", "CB", "N", "O"], + "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], + "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], + "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], + "CYS": ["C", "CA", "CB", "N", "O", "SG"], + "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], + "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], + "GLY": ["C", "CA", "N", "O"], + "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], + "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], + "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], + "LYS": ["C", "CA", "CB", "CG", "CD", "CE", "N", "NZ", "O"], + "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], + "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], + "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], + "SER": ["C", "CA", "CB", "N", "O", "OG"], + "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], + "TRP": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE2", "CE3", "CZ2", "CZ3", "CH2", "N", "NE1", "O"], + "TYR": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O", "OH"], + "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +# TODO: ^ interpret this +residue_atom_renaming_swaps: Dict[str, Dict[str, str]] = { + "ASP": {"OD1": "OD2"}, + "GLU": {"OE1": "OE2"}, + "PHE": {"CD1": "CD2", "CE1": "CE2"}, + "TYR": {"CD1": "CD2", "CE1": "CE2"}, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius: Dict[str, float] = { + "C": 1.7, + "N": 1.55, + "O": 1.52, + "S": 1.8, +} + +Bond = collections.namedtuple("Bond", ["atom1_name", "atom2_name", "length", "stddev"]) +BondAngle = collections.namedtuple( + "BondAngle", + ["atom1_name", "atom2_name", "atom3name", "angle_rad", "stddev"], +) + + +def map_structure_with_atom_order(in_list: list, first_call: bool = True) -> list: + # Maps strings in a nested list structure to their corresponding index in atom_order + if first_call: + in_list = copy.deepcopy(in_list) + for i in range(len(in_list)): + if isinstance(in_list[i], list): + in_list[i] = map_structure_with_atom_order(in_list[i], first_call=False) + elif isinstance(in_list[i], str): + in_list[i] = atom_order[in_list[i]] + else: + raise TypeError("Unexpected type when mapping nested lists!") + return in_list + + +@functools.lru_cache(maxsize=None) +def load_stereo_chemical_props() -> Tuple[ + Mapping[str, List[Bond]], + Mapping[str, List[Bond]], + Mapping[str, List[BondAngle]], +]: + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate bond angles into the length of the opposite + edge of the triangle ("residue_virtual_bonds"). + + Returns: + residue_bonds: dict that maps resname --> list of Bond tuples residue_virtual_bonds: dict that maps resname --> + list of Bond tuples residue_bond_angles: dict that maps resname --> list of BondAngle tuples + """ + # TODO: this file should be downloaded in a setup script + stereo_chemical_props = resources.read_text("openfold.resources", "stereo_chemical_props.txt") + + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds: Dict[str, List[Bond]] = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, bond_length, stddev = line.split() + atom1, atom2 = bond.split("-") + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append(Bond(atom1, atom2, float(bond_length), float(stddev))) + residue_bonds["UNK"] = [] + + # Load bond angles. + residue_bond_angles: Dict[str, List[BondAngle]] = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == "-": + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split("-") + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + ) + ) + residue_bond_angles["UNK"] = [] + + def make_bond_key(atom1_name: str, atom2_name: str) -> str: + """Unique key to lookup bonds.""" + return "-".join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds: Dict[str, List[Bond]] = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache: Dict[str, Bond] = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt( + (dl_dgamma * ba.stddev) ** 2 + (dl_db1 * bond1.stddev) ** 2 + (dl_db2 * bond2.stddev) ** 2 + ) + residue_virtual_bonds[resname].append(Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n: Tuple[float, float] = (1.329, 1.341) +between_res_bond_length_stddev_c_n: Tuple[float, float] = (0.014, 0.016) + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca: Tuple[float, float] = (-0.5203, 0.0353) # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n: Tuple[float, float] = (-0.4473, 0.0311) # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). +atom_types: List[str] = [ + "N", + "CA", + "C", + "CB", + "O", + "CG", + "CG1", + "CG2", + "OG", + "OG1", + "SG", + "CD", + "CD1", + "CD2", + "ND1", + "ND2", + "OD1", + "OD2", + "SD", + "CE", + "CE1", + "CE2", + "CE3", + "NE", + "NE1", + "NE2", + "OE1", + "OE2", + "CH2", + "NH1", + "NH2", + "OH", + "CZ", + "CZ2", + "CZ3", + "NZ", + "OXT", +] +atom_order: Dict[str, int] = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names: Dict[str, List[str]] = { + "ALA": ["N", "CA", "C", "O", "CB", "", "", "", "", "", "", "", "", ""], + "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2", "", "", ""], + "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2", "", "", "", "", "", ""], + "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2", "", "", "", "", "", ""], + "CYS": ["N", "CA", "C", "O", "CB", "SG", "", "", "", "", "", "", "", ""], + "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2", "", "", "", "", ""], + "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2", "", "", "", "", ""], + "GLY": ["N", "CA", "C", "O", "", "", "", "", "", "", "", "", "", ""], + "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2", "", "", "", ""], + "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1", "", "", "", "", "", ""], + "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "", "", "", "", "", ""], + "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ", "", "", "", "", ""], + "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE", "", "", "", "", "", ""], + "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "", "", ""], + "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD", "", "", "", "", "", "", ""], + "SER": ["N", "CA", "C", "O", "CB", "OG", "", "", "", "", "", "", "", ""], + "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2", "", "", "", "", "", "", ""], + "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], + "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH", "", ""], + "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "", "", "", "", "", "", ""], + "UNK": ["", "", "", "", "", "", "", "", "", "", "", "", "", ""], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes: List[str] = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", +] +restype_order: Dict[str, int] = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x: List[str] = restypes + ["X"] +restype_order_with_x: Dict[str, int] = {restype: i for i, restype in enumerate(restypes_with_x)} + + +def sequence_to_onehot(sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. + map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain amino acid 'X', an error will be thrown. + If False, any amino acid not in the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + "The mapping must have values from 0 to num_unique_aas-1 without any gaps. Got: %s" + % sorted(mapping.values()) + ) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping["X"]) + else: + raise ValueError(f"Invalid character in the sequence: {aa_type}") + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3: Dict[str, str] = { + "A": "ALA", + "R": "ARG", + "N": "ASN", + "D": "ASP", + "C": "CYS", + "Q": "GLN", + "E": "GLU", + "G": "GLY", + "H": "HIS", + "I": "ILE", + "L": "LEU", + "K": "LYS", + "M": "MET", + "F": "PHE", + "P": "PRO", + "S": "SER", + "T": "THR", + "W": "TRP", + "Y": "TYR", + "V": "VAL", +} + + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1: Dict[str, str] = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = "UNK" + +resnames: List[str] = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx: Dict[str, int] = {resname: i for i, resname in enumerate(resnames)} + + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID: Dict[str, int] = { + "A": 0, + "B": 2, + "C": 1, + "D": 2, + "E": 3, + "F": 4, + "G": 5, + "H": 6, + "I": 7, + "J": 20, + "K": 8, + "L": 9, + "M": 10, + "N": 11, + "O": 20, + "P": 12, + "Q": 13, + "R": 14, + "S": 15, + "T": 16, + "U": 1, + "V": 17, + "W": 18, + "X": 20, + "Y": 19, + "Z": 3, + "-": 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA: Dict[int, str] = { + 0: "A", + 1: "C", # Also U. + 2: "D", # Also B. + 3: "E", # Also Z. + 4: "F", + 5: "G", + 6: "H", + 7: "I", + 8: "K", + 9: "L", + 10: "M", + 11: "N", + 12: "P", + 13: "Q", + 14: "R", + 15: "S", + 16: "T", + 17: "V", + 18: "W", + 19: "Y", + 20: "X", # Includes J and O. + 21: "-", +} + +restypes_with_x_and_gap: List[str] = restypes + ["X", "-"] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE: Tuple[int, ...] = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) for i in range(len(restypes_with_x_and_gap)) +) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). + mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices_list: List[List[List[str]]] = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices_ours: list = map_structure_with_atom_order(chi_angles_atom_indices_list) +chi_angles_atom_indices = np.array( + [chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) for chi_atoms in chi_angles_atom_indices_list] +) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom: Dict[Tuple[str, str], List[Tuple[int, int]]] = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex: np.ndarray, ey: np.ndarray, translation: np.ndarray) -> np.ndarray: + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. + ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants() -> None: + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions: Dict[str, np.ndarray] = { + name: np.array(pos) for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["N"] - atom_positions["CA"], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions["N"], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions["C"] - atom_positions["CA"], + ey=atom_positions["CA"] - atom_positions["N"], + translation=atom_positions["C"], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [atom_positions[name] for name in base_atom_names] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds( + overlap_tolerance: float = 1.5, + bond_length_tolerance_factor: int = 15, +) -> Dict[str, np.ndarray]: + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev + return { + "lower_bound": restype_atom14_bond_lower_bound, # shape (21,14,14) + "upper_bound": restype_atom14_bond_upper_bound, # shape (21,14,14) + "stddev": restype_atom14_bond_stddev, # shape (21,14,14) + } + + +restype_atom14_ambiguous_atoms = np.zeros((21, 14), dtype=np.float32) +restype_atom14_ambiguous_atoms_swap_idx: np.ndarray = np.tile(np.arange(14, dtype=int), (21, 1)) + + +def _make_atom14_ambiguity_feats() -> None: + for res, pairs in residue_atom_renaming_swaps.items(): + res_idx = restype_order[restype_3to1[res]] + for atom1, atom2 in pairs.items(): + atom1_idx = restype_name_to_atom14_names[res].index(atom1) + atom2_idx = restype_name_to_atom14_names[res].index(atom2) + restype_atom14_ambiguous_atoms[res_idx, atom1_idx] = 1 + restype_atom14_ambiguous_atoms[res_idx, atom2_idx] = 1 + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom1_idx] = atom2_idx + restype_atom14_ambiguous_atoms_swap_idx[res_idx, atom2_idx] = atom1_idx + + +_make_atom14_ambiguity_feats() + + +def aatype_to_str_sequence(aatype: Sequence[int]) -> str: + return "".join([restypes_with_x[aatype[i]] for i in range(len(aatype))]) diff --git a/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..efe72e4905b81faee5fe44131f1ba1b75856bd1c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/openfold_utils/tensor_utils.py @@ -0,0 +1,140 @@ +# Copyright 2021 AlQuraishi Laboratory +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Callable, Dict, List, Type, TypeVar, Union, overload + +import torch +import torch.nn as nn +import torch.types + + +def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor: + # The first operation in a checkpoint can't be in-place, but it's + # nice to have in-place addition during inference. Thus... + if not inplace: + m1 = m1 + m2 + else: + m1 += m2 + + return m1 + + +def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor: + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +def flatten_final_dims(t: torch.Tensor, no_dims: int) -> torch.Tensor: + return t.reshape(t.shape[:-no_dims] + (-1,)) + + +def masked_mean(mask: torch.Tensor, value: torch.Tensor, dim: int, eps: float = 1e-4) -> torch.Tensor: + mask = mask.expand(*value.shape) + return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +def pts_to_distogram( + pts: torch.Tensor, min_bin: torch.types.Number = 2.3125, max_bin: torch.types.Number = 21.6875, no_bins: int = 64 +) -> torch.Tensor: + boundaries = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device) + dists = torch.sqrt(torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)) + return torch.bucketize(dists, boundaries) + + +def dict_multimap(fn: Callable[[list], Any], dicts: List[dict]) -> dict: + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if isinstance(v, dict): + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + + return new_dict + + +def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor: + reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) + diffs = x[..., None] - reshaped_bins + am = torch.argmin(torch.abs(diffs), dim=-1) + return nn.functional.one_hot(am, num_classes=len(v_bins)).float() + + +def batched_gather(data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0) -> torch.Tensor: + ranges: List[Union[slice, torch.Tensor]] = [] + for i, s in enumerate(data.shape[:no_batch_dims]): + r = torch.arange(s) + r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims: List[Union[slice, torch.Tensor]] = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] + remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds + ranges.extend(remaining_dims) + # Matt note: Editing this to get around the behaviour of using a list as an array index changing + # in recent Numpy versions + return data[tuple(ranges)] + + +T = TypeVar("T") + + +# With tree_map, a poor man's JAX tree_map +def dict_map( + fn: Callable[[T], Any], dic: Dict[Any, Union[dict, list, tuple, T]], leaf_type: Type[T] +) -> Dict[Any, Union[dict, list, tuple, Any]]: + new_dict: Dict[Any, Union[dict, list, tuple, Any]] = {} + for k, v in dic.items(): + if isinstance(v, dict): + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + + return new_dict + + +@overload +def tree_map(fn: Callable[[T], Any], tree: T, leaf_type: Type[T]) -> Any: ... + + +@overload +def tree_map(fn: Callable[[T], Any], tree: dict, leaf_type: Type[T]) -> dict: ... + + +@overload +def tree_map(fn: Callable[[T], Any], tree: list, leaf_type: Type[T]) -> list: ... + + +@overload +def tree_map(fn: Callable[[T], Any], tree: tuple, leaf_type: Type[T]) -> tuple: ... + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple(tree_map(fn, x, leaf_type) for x in tree) + elif isinstance(tree, leaf_type): + return fn(tree) + else: + print(type(tree)) + raise TypeError("Not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) diff --git a/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py b/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc433e350e13c33fc779092baf52197d8aa5e0d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/esm/tokenization_esm.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for ESM.""" + +import os +from typing import List, Optional + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} + + +def load_vocab_file(vocab_file): + with open(vocab_file, "r") as f: + lines = f.read().splitlines() + return [l.strip() for l in lines] + + +class EsmTokenizer(PreTrainedTokenizer): + """ + Constructs an ESM tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + cls_token="", + pad_token="", + mask_token="", + eos_token="", + **kwargs, + ): + self.all_tokens = load_vocab_file(vocab_file) + self._id_to_token = dict(enumerate(self.all_tokens)) + self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)} + super().__init__( + unk_token=unk_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + eos_token=eos_token, + **kwargs, + ) + + # TODO, all the tokens are added? But they are also part of the vocab... bit strange. + # none of them are special, but they all need special splitting. + + self.unique_no_split_tokens = self.all_tokens + self._update_trie(self.unique_no_split_tokens) + + def _convert_id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def _convert_token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def _tokenize(self, text, **kwargs): + return text.split() + + def get_vocab(self): + base_vocab = self._token_to_id.copy() + base_vocab.update(self.added_tokens_encoder) + return base_vocab + + def token_to_id(self, token: str) -> int: + return self._token_to_id.get(token, self._token_to_id.get(self.unk_token)) + + def id_to_token(self, index: int) -> str: + return self._id_to_token.get(index, self.unk_token) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + cls = [self.cls_token_id] + sep = [self.eos_token_id] # No sep token in ESM vocabulary + if token_ids_1 is None: + if self.eos_token_id is None: + return cls + token_ids_0 + else: + return cls + token_ids_0 + sep + elif self.eos_token_id is None: + raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!") + return cls + token_ids_0 + sep + token_ids_1 + sep # Multiple inputs always have an EOS token + + def get_special_tokens_mask( + self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods. + + Args: + token_ids_0 (`List[int]`): + List of ids of the first sequence. + token_ids_1 (`List[int]`, *optional*): + List of ids of the second sequence. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + + return [1 if token in self.all_special_ids else 0 for token in token_ids_0] + mask = [1] + ([0] * len(token_ids_0)) + [1] + if token_ids_1 is not None: + mask += [0] * len(token_ids_1) + [1] + return mask + + def save_vocabulary(self, save_directory, filename_prefix): + vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt") + with open(vocab_file, "w") as f: + f.write("\n".join(self.all_tokens)) + return (vocab_file,) + + @property + def vocab_size(self) -> int: + return len(self.all_tokens) + + +__all__ = ["EsmTokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/falcon/__init__.py b/docs/transformers/build/lib/transformers/models/falcon/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9789767f11402264660b5dec0b5cae2466ee9d8 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/falcon/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_falcon import * + from .modeling_falcon import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)