Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py +135 -0
- docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py +1597 -0
- docs/transformers/build/lib/transformers/models/vit/__init__.py +32 -0
- docs/transformers/build/lib/transformers/models/vit/configuration_vit.py +151 -0
- docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py +218 -0
- docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py +254 -0
- docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py +38 -0
- docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py +288 -0
- docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py +45 -0
- docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py +677 -0
- docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py +907 -0
- docs/transformers/build/lib/transformers/models/vit/modeling_vit.py +883 -0
- docs/transformers/build/lib/transformers/models/vit_mae/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py +140 -0
- docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py +178 -0
- docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py +1375 -0
- docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py +1163 -0
- docs/transformers/build/lib/transformers/models/vit_msn/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py +115 -0
- docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py +245 -0
- docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py +741 -0
- docs/transformers/build/lib/transformers/models/vitdet/__init__.py +27 -0
- docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py +156 -0
- docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py +883 -0
- docs/transformers/build/lib/transformers/models/vitmatte/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py +136 -0
- docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py +170 -0
- docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py +272 -0
- docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py +341 -0
- docs/transformers/build/lib/transformers/models/vitpose/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py +126 -0
- docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py +428 -0
- docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py +684 -0
- docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py +340 -0
- docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py +17 -0
- docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +139 -0
- docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +579 -0
- docs/transformers/build/lib/transformers/models/vits/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/vits/configuration_vits.py +253 -0
- docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py +390 -0
- docs/transformers/build/lib/transformers/models/vits/modeling_vits.py +1493 -0
- docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py +246 -0
- docs/transformers/build/lib/transformers/models/vivit/__init__.py +28 -0
- docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py +119 -0
- docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py +231 -0
- docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py +407 -0
- docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py +844 -0
- docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py +32 -0
- docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py +347 -0
- docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +385 -0
docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py
ADDED
@@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisualBERT model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class VisualBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VisualBertModel`]. It is used to instantiate a
    VisualBERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the VisualBERT
    [uclanlp/visualbert-vqa-coco-pre](https://huggingface.co/uclanlp/visualbert-vqa-coco-pre) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the VisualBERT model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`VisualBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        visual_embedding_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the visual embeddings to be passed to the model.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`VisualBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        bypass_transformer (`bool`, *optional*, defaults to `False`):
            Whether or not the model should bypass the transformer for the visual embeddings. If set to `True`, the
            model directly concatenates the visual embeddings from [`VisualBertEmbeddings`] with the text output from
            the transformer, and then passes it to a self-attention layer.
        special_visual_initialize (`bool`, *optional*, defaults to `True`):
            Whether or not the visual token type and position type embedding weights should be initialized the same as
            the textual token type and position type embeddings. When set to `True`, the weights of the textual token
            type and position type embeddings are copied to the respective visual embedding layers.

    Example:

    ```python
    >>> from transformers import VisualBertConfig, VisualBertModel

    >>> # Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
    >>> configuration = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

    >>> # Initializing a model (with random weights) from the visualbert-vqa-coco-pre style configuration
    >>> model = VisualBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "visual_bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        visual_embedding_dim=512,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        bypass_transformer=False,
        special_visual_initialize=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.visual_embedding_dim = visual_embedding_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.type_vocab_size = type_vocab_size
        self.layer_norm_eps = layer_norm_eps
        self.bypass_transformer = bypass_transformer
        self.special_visual_initialize = special_visual_initialize


__all__ = ["VisualBertConfig"]
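For reference, a minimal sketch of how this configuration drives the model's visual projection. It is not part of the diff; the 2048-dimensional visual features are an assumed example (e.g. region features from an external object detector), not a value taken from this file.

```python
# Minimal sketch (not part of the diff). Assumes detector features of size 2048;
# VisualBertConfig defaults to visual_embedding_dim=512.
from transformers import VisualBertConfig, VisualBertModel

config = VisualBertConfig(
    visual_embedding_dim=2048,  # dimensionality of the incoming visual features (assumed)
    hidden_size=768,            # dimensionality the encoder works in
)
model = VisualBertModel(config)  # randomly initialized weights

# VisualBertEmbeddings owns a Linear(visual_embedding_dim -> hidden_size) projection,
# so visual features of shape (batch, regions, 2048) are mapped to (batch, regions, 768).
print(model.embeddings.visual_projection)  # Linear(in_features=2048, out_features=768, bias=True)
```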
docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py
ADDED
@@ -0,0 +1,1597 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The UCLA NLP Authors and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch VisualBERT model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...modeling_outputs import (
|
| 28 |
+
BaseModelOutput,
|
| 29 |
+
BaseModelOutputWithPooling,
|
| 30 |
+
MultipleChoiceModelOutput,
|
| 31 |
+
SequenceClassifierOutput,
|
| 32 |
+
)
|
| 33 |
+
from ...modeling_utils import PreTrainedModel
|
| 34 |
+
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 35 |
+
from ...utils import (
|
| 36 |
+
ModelOutput,
|
| 37 |
+
add_start_docstrings,
|
| 38 |
+
add_start_docstrings_to_model_forward,
|
| 39 |
+
logging,
|
| 40 |
+
replace_return_docstrings,
|
| 41 |
+
)
|
| 42 |
+
from .configuration_visual_bert import VisualBertConfig
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)
|
| 46 |
+
|
| 47 |
+
_CONFIG_FOR_DOC = "VisualBertConfig"
|
| 48 |
+
_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class VisualBertEmbeddings(nn.Module):
|
| 52 |
+
"""Construct the embeddings from word, position and token_type embeddings and visual embeddings."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, config):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 57 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 58 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 59 |
+
|
| 60 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 61 |
+
# any TensorFlow checkpoint file
|
| 62 |
+
|
| 63 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 64 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 65 |
+
|
| 66 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 67 |
+
self.register_buffer(
|
| 68 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# For Visual Features
|
| 72 |
+
# Token type and position embedding for image features
|
| 73 |
+
self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 74 |
+
self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 75 |
+
|
| 76 |
+
if config.special_visual_initialize:
|
| 77 |
+
self.visual_token_type_embeddings.weight.data = nn.Parameter(
|
| 78 |
+
self.token_type_embeddings.weight.data.clone(), requires_grad=True
|
| 79 |
+
)
|
| 80 |
+
self.visual_position_embeddings.weight.data = nn.Parameter(
|
| 81 |
+
self.position_embeddings.weight.data.clone(), requires_grad=True
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)
|
| 85 |
+
|
| 86 |
+
def forward(
|
| 87 |
+
self,
|
| 88 |
+
input_ids=None,
|
| 89 |
+
token_type_ids=None,
|
| 90 |
+
position_ids=None,
|
| 91 |
+
inputs_embeds=None,
|
| 92 |
+
visual_embeds=None,
|
| 93 |
+
visual_token_type_ids=None,
|
| 94 |
+
image_text_alignment=None,
|
| 95 |
+
):
|
| 96 |
+
if input_ids is not None:
|
| 97 |
+
input_shape = input_ids.size()
|
| 98 |
+
else:
|
| 99 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 100 |
+
|
| 101 |
+
seq_length = input_shape[1]
|
| 102 |
+
|
| 103 |
+
if position_ids is None:
|
| 104 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 105 |
+
|
| 106 |
+
if inputs_embeds is None:
|
| 107 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 108 |
+
|
| 109 |
+
if token_type_ids is None:
|
| 110 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
| 111 |
+
|
| 112 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 113 |
+
|
| 114 |
+
embeddings = inputs_embeds + token_type_embeddings
|
| 115 |
+
|
| 116 |
+
# Absolute Position Embeddings
|
| 117 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 118 |
+
embeddings += position_embeddings
|
| 119 |
+
|
| 120 |
+
if visual_embeds is not None:
|
| 121 |
+
if visual_token_type_ids is None:
|
| 122 |
+
visual_token_type_ids = torch.ones(
|
| 123 |
+
visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
visual_embeds = self.visual_projection(visual_embeds)
|
| 127 |
+
visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)
|
| 128 |
+
|
| 129 |
+
if image_text_alignment is not None:
|
| 130 |
+
# image_text_alignment = Batch x image_length x alignment_number.
|
| 131 |
+
# Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.
|
| 132 |
+
|
| 133 |
+
dtype = token_type_embeddings.dtype
|
| 134 |
+
image_text_alignment_mask = (image_text_alignment != -1).long()
|
| 135 |
+
# Get rid of the -1.
|
| 136 |
+
image_text_alignment = image_text_alignment_mask * image_text_alignment
|
| 137 |
+
|
| 138 |
+
# Batch x image_length x alignment length x dim
|
| 139 |
+
visual_position_embeddings = self.position_embeddings(image_text_alignment)
|
| 140 |
+
visual_position_embeddings *= image_text_alignment_mask.to(dtype=dtype).unsqueeze(-1)
|
| 141 |
+
visual_position_embeddings = visual_position_embeddings.sum(2)
|
| 142 |
+
|
| 143 |
+
# We want to averge along the alignment_number dimension.
|
| 144 |
+
image_text_alignment_mask = image_text_alignment_mask.to(dtype=dtype).sum(2)
|
| 145 |
+
|
| 146 |
+
if (image_text_alignment_mask == 0).sum() != 0:
|
| 147 |
+
image_text_alignment_mask[image_text_alignment_mask == 0] = 1 # Avoid divide by zero error
|
| 148 |
+
logger.warning(
|
| 149 |
+
"Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero"
|
| 150 |
+
" error."
|
| 151 |
+
)
|
| 152 |
+
visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)
|
| 153 |
+
|
| 154 |
+
visual_position_ids = torch.zeros(
|
| 155 |
+
*visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# When fine-tuning the detector , the image_text_alignment is sometimes padded too long.
|
| 159 |
+
if visual_position_embeddings.size(1) != visual_embeds.size(1):
|
| 160 |
+
if visual_position_embeddings.size(1) < visual_embeds.size(1):
|
| 161 |
+
raise ValueError(
|
| 162 |
+
f"Visual position embeddings length: {visual_position_embeddings.size(1)} "
|
| 163 |
+
f"should be the same as `visual_embeds` length: {visual_embeds.size(1)}"
|
| 164 |
+
)
|
| 165 |
+
visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]
|
| 166 |
+
|
| 167 |
+
visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(
|
| 168 |
+
visual_position_ids
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
visual_position_ids = torch.zeros(
|
| 172 |
+
*visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
|
| 173 |
+
)
|
| 174 |
+
visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)
|
| 175 |
+
|
| 176 |
+
visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings
|
| 177 |
+
|
| 178 |
+
embeddings = torch.cat((embeddings, visual_embeddings), dim=1)
|
| 179 |
+
|
| 180 |
+
embeddings = self.LayerNorm(embeddings)
|
| 181 |
+
embeddings = self.dropout(embeddings)
|
| 182 |
+
return embeddings
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class VisualBertSelfAttention(nn.Module):
|
| 186 |
+
def __init__(self, config):
|
| 187 |
+
super().__init__()
|
| 188 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 189 |
+
raise ValueError(
|
| 190 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
| 191 |
+
f"heads ({config.num_attention_heads})"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
self.num_attention_heads = config.num_attention_heads
|
| 195 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 196 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 197 |
+
|
| 198 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 199 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 200 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 201 |
+
|
| 202 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 203 |
+
|
| 204 |
+
def transpose_for_scores(self, x):
|
| 205 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 206 |
+
x = x.view(*new_x_shape)
|
| 207 |
+
return x.permute(0, 2, 1, 3)
|
| 208 |
+
|
| 209 |
+
def forward(
|
| 210 |
+
self,
|
| 211 |
+
hidden_states,
|
| 212 |
+
attention_mask=None,
|
| 213 |
+
head_mask=None,
|
| 214 |
+
output_attentions=False,
|
| 215 |
+
):
|
| 216 |
+
mixed_query_layer = self.query(hidden_states)
|
| 217 |
+
|
| 218 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 219 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 220 |
+
|
| 221 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 222 |
+
|
| 223 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 224 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 225 |
+
|
| 226 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 227 |
+
if attention_mask is not None:
|
| 228 |
+
# Apply the attention mask is (precomputed for all layers in VisualBertSelfAttentionModel forward() function)
|
| 229 |
+
attention_scores = attention_scores + attention_mask
|
| 230 |
+
|
| 231 |
+
# Normalize the attention scores to probabilities.
|
| 232 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
| 233 |
+
|
| 234 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 235 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 236 |
+
attention_probs = self.dropout(attention_probs)
|
| 237 |
+
|
| 238 |
+
# Mask heads if we want to
|
| 239 |
+
if head_mask is not None:
|
| 240 |
+
attention_probs = attention_probs * head_mask
|
| 241 |
+
|
| 242 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 243 |
+
|
| 244 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 245 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 246 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 247 |
+
|
| 248 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 249 |
+
|
| 250 |
+
return outputs
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->VisualBert
|
| 254 |
+
class VisualBertSelfOutput(nn.Module):
|
| 255 |
+
def __init__(self, config):
|
| 256 |
+
super().__init__()
|
| 257 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 258 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 259 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 260 |
+
|
| 261 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 262 |
+
hidden_states = self.dense(hidden_states)
|
| 263 |
+
hidden_states = self.dropout(hidden_states)
|
| 264 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 265 |
+
return hidden_states
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class VisualBertAttention(nn.Module):
|
| 269 |
+
def __init__(self, config):
|
| 270 |
+
super().__init__()
|
| 271 |
+
self.self = VisualBertSelfAttention(config)
|
| 272 |
+
self.output = VisualBertSelfOutput(config)
|
| 273 |
+
self.pruned_heads = set()
|
| 274 |
+
|
| 275 |
+
def prune_heads(self, heads):
|
| 276 |
+
if len(heads) == 0:
|
| 277 |
+
return
|
| 278 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 279 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Prune linear layers
|
| 283 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 284 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 285 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 286 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 287 |
+
|
| 288 |
+
# Update hyper params and store pruned heads
|
| 289 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 290 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 291 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 292 |
+
|
| 293 |
+
def forward(
|
| 294 |
+
self,
|
| 295 |
+
hidden_states,
|
| 296 |
+
attention_mask=None,
|
| 297 |
+
head_mask=None,
|
| 298 |
+
output_attentions=False,
|
| 299 |
+
):
|
| 300 |
+
self_outputs = self.self(
|
| 301 |
+
hidden_states,
|
| 302 |
+
attention_mask,
|
| 303 |
+
head_mask,
|
| 304 |
+
output_attentions,
|
| 305 |
+
)
|
| 306 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 307 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 308 |
+
return outputs
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->VisualBert
|
| 312 |
+
class VisualBertIntermediate(nn.Module):
|
| 313 |
+
def __init__(self, config):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 316 |
+
if isinstance(config.hidden_act, str):
|
| 317 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 318 |
+
else:
|
| 319 |
+
self.intermediate_act_fn = config.hidden_act
|
| 320 |
+
|
| 321 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 322 |
+
hidden_states = self.dense(hidden_states)
|
| 323 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 324 |
+
return hidden_states
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->VisualBert
|
| 328 |
+
class VisualBertOutput(nn.Module):
|
| 329 |
+
def __init__(self, config):
|
| 330 |
+
super().__init__()
|
| 331 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 332 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 333 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 334 |
+
|
| 335 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 336 |
+
hidden_states = self.dense(hidden_states)
|
| 337 |
+
hidden_states = self.dropout(hidden_states)
|
| 338 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 339 |
+
return hidden_states
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class VisualBertLayer(nn.Module):
|
| 343 |
+
def __init__(self, config):
|
| 344 |
+
super().__init__()
|
| 345 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 346 |
+
self.seq_len_dim = 1
|
| 347 |
+
self.attention = VisualBertAttention(config)
|
| 348 |
+
self.intermediate = VisualBertIntermediate(config)
|
| 349 |
+
self.output = VisualBertOutput(config)
|
| 350 |
+
|
| 351 |
+
def forward(
|
| 352 |
+
self,
|
| 353 |
+
hidden_states,
|
| 354 |
+
attention_mask=None,
|
| 355 |
+
head_mask=None,
|
| 356 |
+
output_attentions=False,
|
| 357 |
+
):
|
| 358 |
+
self_attention_outputs = self.attention(
|
| 359 |
+
hidden_states,
|
| 360 |
+
attention_mask,
|
| 361 |
+
head_mask,
|
| 362 |
+
output_attentions=output_attentions,
|
| 363 |
+
)
|
| 364 |
+
attention_output = self_attention_outputs[0]
|
| 365 |
+
|
| 366 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 367 |
+
|
| 368 |
+
layer_output = apply_chunking_to_forward(
|
| 369 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
| 370 |
+
)
|
| 371 |
+
outputs = (layer_output,) + outputs
|
| 372 |
+
|
| 373 |
+
return outputs
|
| 374 |
+
|
| 375 |
+
def feed_forward_chunk(self, attention_output):
|
| 376 |
+
intermediate_output = self.intermediate(attention_output)
|
| 377 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 378 |
+
return layer_output
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class VisualBertEncoder(nn.Module):
|
| 382 |
+
def __init__(self, config):
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.config = config
|
| 385 |
+
self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])
|
| 386 |
+
self.gradient_checkpointing = False
|
| 387 |
+
|
| 388 |
+
def forward(
|
| 389 |
+
self,
|
| 390 |
+
hidden_states,
|
| 391 |
+
attention_mask=None,
|
| 392 |
+
head_mask=None,
|
| 393 |
+
output_attentions=False,
|
| 394 |
+
output_hidden_states=False,
|
| 395 |
+
return_dict=True,
|
| 396 |
+
):
|
| 397 |
+
all_hidden_states = () if output_hidden_states else None
|
| 398 |
+
all_self_attentions = () if output_attentions else None
|
| 399 |
+
|
| 400 |
+
for i, layer_module in enumerate(self.layer):
|
| 401 |
+
if output_hidden_states:
|
| 402 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 403 |
+
|
| 404 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 405 |
+
|
| 406 |
+
if self.gradient_checkpointing and self.training:
|
| 407 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 408 |
+
layer_module.__call__,
|
| 409 |
+
hidden_states,
|
| 410 |
+
attention_mask,
|
| 411 |
+
layer_head_mask,
|
| 412 |
+
output_attentions,
|
| 413 |
+
)
|
| 414 |
+
else:
|
| 415 |
+
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
|
| 416 |
+
|
| 417 |
+
hidden_states = layer_outputs[0]
|
| 418 |
+
if output_attentions:
|
| 419 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 420 |
+
|
| 421 |
+
if output_hidden_states:
|
| 422 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 423 |
+
|
| 424 |
+
if not return_dict:
|
| 425 |
+
return tuple(
|
| 426 |
+
v
|
| 427 |
+
for v in [
|
| 428 |
+
hidden_states,
|
| 429 |
+
all_hidden_states,
|
| 430 |
+
all_self_attentions,
|
| 431 |
+
]
|
| 432 |
+
if v is not None
|
| 433 |
+
)
|
| 434 |
+
return BaseModelOutput(
|
| 435 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->VisualBert
|
| 440 |
+
class VisualBertPooler(nn.Module):
|
| 441 |
+
def __init__(self, config):
|
| 442 |
+
super().__init__()
|
| 443 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 444 |
+
self.activation = nn.Tanh()
|
| 445 |
+
|
| 446 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 447 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 448 |
+
# to the first token.
|
| 449 |
+
first_token_tensor = hidden_states[:, 0]
|
| 450 |
+
pooled_output = self.dense(first_token_tensor)
|
| 451 |
+
pooled_output = self.activation(pooled_output)
|
| 452 |
+
return pooled_output
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->VisualBert
|
| 456 |
+
class VisualBertPredictionHeadTransform(nn.Module):
|
| 457 |
+
def __init__(self, config):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 460 |
+
if isinstance(config.hidden_act, str):
|
| 461 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 462 |
+
else:
|
| 463 |
+
self.transform_act_fn = config.hidden_act
|
| 464 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 465 |
+
|
| 466 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 467 |
+
hidden_states = self.dense(hidden_states)
|
| 468 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 469 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 470 |
+
return hidden_states
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->VisualBert
|
| 474 |
+
class VisualBertLMPredictionHead(nn.Module):
|
| 475 |
+
def __init__(self, config):
|
| 476 |
+
super().__init__()
|
| 477 |
+
self.transform = VisualBertPredictionHeadTransform(config)
|
| 478 |
+
|
| 479 |
+
# The output weights are the same as the input embeddings, but there is
|
| 480 |
+
# an output-only bias for each token.
|
| 481 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 482 |
+
|
| 483 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 484 |
+
|
| 485 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 486 |
+
self.decoder.bias = self.bias
|
| 487 |
+
|
| 488 |
+
def _tie_weights(self):
|
| 489 |
+
self.decoder.bias = self.bias
|
| 490 |
+
|
| 491 |
+
def forward(self, hidden_states):
|
| 492 |
+
hidden_states = self.transform(hidden_states)
|
| 493 |
+
hidden_states = self.decoder(hidden_states)
|
| 494 |
+
return hidden_states
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->VisualBert
|
| 498 |
+
class VisualBertPreTrainingHeads(nn.Module):
|
| 499 |
+
def __init__(self, config):
|
| 500 |
+
super().__init__()
|
| 501 |
+
self.predictions = VisualBertLMPredictionHead(config)
|
| 502 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 503 |
+
|
| 504 |
+
def forward(self, sequence_output, pooled_output):
|
| 505 |
+
prediction_scores = self.predictions(sequence_output)
|
| 506 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 507 |
+
return prediction_scores, seq_relationship_score
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class VisualBertPreTrainedModel(PreTrainedModel):
|
| 511 |
+
"""
|
| 512 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 513 |
+
models.
|
| 514 |
+
"""
|
| 515 |
+
|
| 516 |
+
config_class = VisualBertConfig
|
| 517 |
+
base_model_prefix = "visual_bert"
|
| 518 |
+
supports_gradient_checkpointing = True
|
| 519 |
+
|
| 520 |
+
def _init_weights(self, module):
|
| 521 |
+
"""Initialize the weights"""
|
| 522 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 523 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 524 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 525 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 526 |
+
if hasattr(module, "bias") and module.bias is not None:
|
| 527 |
+
module.bias.data.zero_()
|
| 528 |
+
elif isinstance(module, nn.LayerNorm):
|
| 529 |
+
module.bias.data.zero_()
|
| 530 |
+
module.weight.data.fill_(1.0)
|
| 531 |
+
elif isinstance(module, VisualBertLMPredictionHead):
|
| 532 |
+
module.bias.data.zero_()
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
@dataclass
|
| 536 |
+
class VisualBertForPreTrainingOutput(ModelOutput):
|
| 537 |
+
"""
|
| 538 |
+
Output type of [`VisualBertForPreTraining`].
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
| 542 |
+
Total loss as the sum of the masked language modeling loss and the sentence-image prediction
|
| 543 |
+
(classification) loss.
|
| 544 |
+
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 545 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 546 |
+
seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
|
| 547 |
+
Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation
|
| 548 |
+
before SoftMax).
|
| 549 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 550 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 551 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
| 552 |
+
|
| 553 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 554 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 555 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 556 |
+
sequence_length)`.
|
| 557 |
+
|
| 558 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 559 |
+
heads.
|
| 560 |
+
"""
|
| 561 |
+
|
| 562 |
+
loss: Optional[torch.FloatTensor] = None
|
| 563 |
+
prediction_logits: Optional[torch.FloatTensor] = None
|
| 564 |
+
seq_relationship_logits: Optional[torch.FloatTensor] = None
|
| 565 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 566 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
VISUAL_BERT_START_DOCSTRING = r"""
|
| 570 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 571 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 572 |
+
etc.)
|
| 573 |
+
|
| 574 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 575 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 576 |
+
and behavior.
|
| 577 |
+
|
| 578 |
+
Parameters:
|
| 579 |
+
config ([`VisualBertConfig`]): Model configuration class with all the parameters of the model.
|
| 580 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 581 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 582 |
+
"""
|
| 583 |
+
|
| 584 |
+
VISUAL_BERT_INPUTS_DOCSTRING = r"""
|
| 585 |
+
Args:
|
| 586 |
+
input_ids (`torch.LongTensor` of shape `({0})`):
|
| 587 |
+
Indices of input sequence tokens in the vocabulary.
|
| 588 |
+
|
| 589 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 590 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 591 |
+
|
| 592 |
+
[What are input IDs?](../glossary#input-ids)
|
| 593 |
+
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
|
| 594 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 595 |
+
|
| 596 |
+
- 1 for tokens that are **not masked**,
|
| 597 |
+
- 0 for tokens that are **masked**.
|
| 598 |
+
|
| 599 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 600 |
+
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 601 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
| 602 |
+
1]`:
|
| 603 |
+
|
| 604 |
+
- 0 corresponds to a *sentence A* token,
|
| 605 |
+
- 1 corresponds to a *sentence B* token.
|
| 606 |
+
|
| 607 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
| 608 |
+
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
| 609 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
| 610 |
+
config.max_position_embeddings - 1]`.
|
| 611 |
+
|
| 612 |
+
[What are position IDs?](../glossary#position-ids)
|
| 613 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 614 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 615 |
+
|
| 616 |
+
- 1 indicates the head is **not masked**,
|
| 617 |
+
- 0 indicates the head is **masked**.
|
| 618 |
+
|
| 619 |
+
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
|
| 620 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
| 621 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
| 622 |
+
model's internal embedding lookup matrix.
|
| 623 |
+
|
| 624 |
+
visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
|
| 625 |
+
The embedded representation of the visual inputs, generally derived using using an object detector.
|
| 626 |
+
|
| 627 |
+
visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, visual_seq_length)`, *optional*):
|
| 628 |
+
Mask to avoid performing attention on visual embeddings. Mask values selected in `[0, 1]`:
|
| 629 |
+
|
| 630 |
+
- 1 for tokens that are **not masked**,
|
| 631 |
+
- 0 for tokens that are **masked**.
|
| 632 |
+
|
| 633 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 634 |
+
visual_token_type_ids (`torch.LongTensor` of shape `(batch_size, visual_seq_length)`, *optional*):
|
| 635 |
+
Segment token indices to indicate different portions of the visual embeds.
|
| 636 |
+
|
| 637 |
+
[What are token type IDs?](../glossary#token-type-ids) The authors of VisualBERT set the
|
| 638 |
+
*visual_token_type_ids* to *1* for all tokens.
|
| 639 |
+
|
| 640 |
+
image_text_alignment (`torch.LongTensor` of shape `(batch_size, visual_seq_length, alignment_number)`, *optional*):
|
| 641 |
+
Image-Text alignment uses to decide the position IDs of the visual embeddings.
|
| 642 |
+
|
| 643 |
+
output_attentions (`bool`, *optional*):
|
| 644 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 645 |
+
tensors for more detail.
|
| 646 |
+
output_hidden_states (`bool`, *optional*):
|
| 647 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 648 |
+
more detail.
|
| 649 |
+
return_dict (`bool`, *optional*):
|
| 650 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 651 |
+
"""
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
@add_start_docstrings(
    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertModel(VisualBertPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = VisualBertEmbeddings(config)
        self.encoder = VisualBertEncoder(config)

        self.pooler = VisualBertPooler(config) if add_pooling_layer else None

        self.bypass_transformer = config.bypass_transformer

        if self.bypass_transformer:
            self.additional_layer = VisualBertLayer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
        r"""

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image.
        from transformers import AutoTokenizer, VisualBertModel
        import torch

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

        inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")
        visual_embeds = get_visual_embeddings(image).unsqueeze(0)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

        inputs.update(
            {
                "visual_embeds": visual_embeds,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            }
        )

        outputs = model(**inputs)

        last_hidden_states = outputs.last_hidden_state
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if visual_embeds is not None:
            visual_input_shape = visual_embeds.size()[:-1]

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        if visual_embeds is not None and visual_attention_mask is None:
            visual_attention_mask = torch.ones(visual_input_shape, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if visual_embeds is not None:
            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
                combined_attention_mask, (batch_size, input_shape + visual_input_shape)
            )

        else:
            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
                attention_mask, (batch_size, input_shape)
            )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
        )

        if self.bypass_transformer and visual_embeds is not None:
            text_length = input_ids.size(1)
            text_embedding_output = embedding_output[:, :text_length, :]
            visual_embedding_output = embedding_output[:, text_length:, :]

            text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]

            encoded_outputs = self.encoder(
                text_embedding_output,
                attention_mask=text_extended_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = encoded_outputs[0]
            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)
            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        else:
            encoder_outputs = self.encoder(
                embedding_output,
                attention_mask=extended_attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            sequence_output = encoder_outputs[0]

            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence-image prediction (classification)` head.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForPreTraining(VisualBertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.visual_bert = VisualBertModel(config)
        self.cls = VisualBertPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        sentence_image_labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], VisualBertForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        sentence_image_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a matching pair of sequence A for the given image,
            - 1 indicates sequence B is a random sequence w.r.t A for the given image.

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
        from transformers import AutoTokenizer, VisualBertForPreTraining

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

        inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
        visual_embeds = get_visual_embeddings(image).unsqueeze(0)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

        inputs.update(
            {
                "visual_embeds": visual_embeds,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            }
        )
        max_length = inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]
        labels = tokenizer(
            "The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=max_length
        )["input_ids"]
        sentence_image_labels = torch.tensor(1).unsqueeze(0)  # Batch_size


        outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels)
        loss = outputs.loss
        prediction_logits = outputs.prediction_logits
        seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
            if labels.size(-1) != total_size:
                raise ValueError(
                    "The labels provided should have same sequence length as total attention mask. "
                    f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
                )

        outputs = self.visual_bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and sentence_image_labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))
            total_loss = masked_lm_loss + sentence_image_loss

        elif labels is not None:
            loss_fct = CrossEntropyLoss()
            total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return VisualBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
    a softmax) e.g. for VCR tasks.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.visual_bert = VisualBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(
        VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
    )
    @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
        from transformers import AutoTokenizer, VisualBertForMultipleChoice
        import torch

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertForMultipleChoice.from_pretrained("uclanlp/visualbert-vcr")

        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        choice0 = "It is eaten with a fork and a knife."
        choice1 = "It is eaten while held in the hand."

        visual_embeds = get_visual_embeddings(image)
        # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)
        visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

        labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

        encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors="pt", padding=True)
        # batch size is 1
        inputs_dict = {k: v.unsqueeze(0) for k, v in encoding.items()}
        inputs_dict.update(
            {
                "visual_embeds": visual_embeds,
                "visual_attention_mask": visual_attention_mask,
                "visual_token_type_ids": visual_token_type_ids,
                "labels": labels,
            }
        )
        outputs = model(**inputs_dict)

        loss = outputs.loss
        logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        visual_embeds = (
            visual_embeds.view(-1, visual_embeds.size(-2), visual_embeds.size(-1))
            if visual_embeds is not None
            else None
        )
        visual_attention_mask = (
            visual_attention_mask.view(-1, visual_attention_mask.size(-1))
            if visual_attention_mask is not None
            else None
        )
        visual_token_type_ids = (
            visual_token_type_ids.view(-1, visual_token_type_ids.size(-1))
            if visual_token_type_ids is not None
            else None
        )

        outputs = self.visual_bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        _, pooled_output = outputs[0], outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.cls(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled
    output) for VQA.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.visual_bert = VisualBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A KLDivLoss is computed between the labels and the returned logits.

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
        from transformers import AutoTokenizer, VisualBertForQuestionAnswering
        import torch

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")

        text = "Who is eating the apple?"
        inputs = tokenizer(text, return_tensors="pt")
        visual_embeds = get_visual_embeddings(image).unsqueeze(0)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

        inputs.update(
            {
                "visual_embeds": visual_embeds,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            }
        )

        labels = torch.tensor([[0.0, 1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        scores = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get the index of the last text token
        index_to_gather = attention_mask.sum(1) - 2  # as in original code

        outputs = self.visual_bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # TO-CHECK: From the original code
        index_to_gather = (
            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, sequence_output.size(-1))
        )
        pooled_output = torch.gather(sequence_output, 1, index_to_gather)

        pooled_output = self.dropout(pooled_output)
        logits = self.cls(pooled_output)
        reshaped_logits = logits.view(-1, self.num_labels)

        loss = None
        if labels is not None:
            loss_fct = nn.KLDivLoss(reduction="batchmean")
            log_softmax = nn.LogSoftmax(dim=-1)
            reshaped_logits = log_softmax(reshaped_logits)
            loss = loss_fct(reshaped_logits, labels.contiguous())
        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled
    output) for Visual Reasoning e.g. for NLVR task.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.visual_bert = VisualBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = nn.Linear(config.hidden_size, config.num_labels)  # 2

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy) against these labels.

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
        from transformers import AutoTokenizer, VisualBertForVisualReasoning
        import torch

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertForVisualReasoning.from_pretrained("uclanlp/visualbert-nlvr2")

        text = "Who is eating the apple?"
        inputs = tokenizer(text, return_tensors="pt")
        visual_embeds = get_visual_embeddings(image).unsqueeze(0)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

        inputs.update(
            {
                "visual_embeds": visual_embeds,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            }
        )

        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        scores = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.visual_bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # sequence_output = outputs[0]
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.cls(pooled_output)
        reshaped_logits = logits.contiguous()

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class VisualBertRegionToPhraseAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )
        self.num_attention_heads = 1  # config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, query, key, attention_mask):
        attention_mask = attention_mask.to(query.dtype)
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min

        mixed_query_layer = self.query(query)
        mixed_key_layer = self.key(key)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        attention_scores = attention_scores + attention_mask

        attention_scores = attention_scores.squeeze(1)
        return attention_scores


@add_start_docstrings(
    """
    VisualBert Model with a Masked Language Modeling head and an attention layer on top for Region-to-Phrase Alignment
    e.g. for Flickr30 Entities task.
    """,
    VISUAL_BERT_START_DOCSTRING,
)
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.visual_bert = VisualBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.cls = VisualBertPreTrainingHeads(config)
        self.attention = VisualBertRegionToPhraseAttention(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        visual_embeds: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.LongTensor] = None,
        visual_token_type_ids: Optional[torch.LongTensor] = None,
        image_text_alignment: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        region_to_phrase_position: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        region_to_phrase_position (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
            The positions depicting the position of the image embedding corresponding to the textual tokens.

        labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length, visual_sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. KLDivLoss is computed against these labels and the
            outputs from the attention layer.

        Returns:

        Example:

        ```python
        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
        from transformers import AutoTokenizer, VisualBertForRegionToPhraseAlignment
        import torch

        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        model = VisualBertForRegionToPhraseAlignment.from_pretrained("uclanlp/visualbert-vqa-coco-pre")

        text = "Who is eating the apple?"
        inputs = tokenizer(text, return_tensors="pt")
        visual_embeds = get_visual_embeddings(image).unsqueeze(0)
        visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
        visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
        region_to_phrase_position = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]))

        inputs.update(
            {
                "region_to_phrase_position": region_to_phrase_position,
                "visual_embeds": visual_embeds,
                "visual_token_type_ids": visual_token_type_ids,
                "visual_attention_mask": visual_attention_mask,
            }
        )

        labels = torch.ones(
            (1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2], visual_embeds.shape[-2])
        )  # Batch size 1

        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        scores = outputs.logits
        ```"""
        if region_to_phrase_position is None:
            raise ValueError("`region_to_phrase_position` should not be None when using Flickr Model.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.visual_bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            visual_embeds=visual_embeds,
            visual_attention_mask=visual_attention_mask,
            visual_token_type_ids=visual_token_type_ids,
            image_text_alignment=image_text_alignment,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        region_to_phrase_position_mask = (region_to_phrase_position != -1).long()

        # Make the -1 become 0
        region_to_phrase_position = region_to_phrase_position * region_to_phrase_position_mask

        # Selected_positions = batch x selected position x dim
        expanded_region_to_phrase_positions = region_to_phrase_position.unsqueeze(2).expand(
            region_to_phrase_position.size(0), region_to_phrase_position.size(1), sequence_output.size(2)
        )
        selected_positions = sequence_output.gather(1, expanded_region_to_phrase_positions)

        # Visual Features = batch x visual_feature_length x dim
        # This will need separate image and visual masks.
        visual_features = sequence_output[:, attention_mask.size(1) :]

        if visual_features.size(1) != visual_attention_mask.size(1):
            raise ValueError(
                f"Visual features length :{visual_features.size(1)} should be the same"
                f" as visual attention mask length: {visual_attention_mask.size(1)}."
            )

        logits = self.attention(selected_positions, visual_features, visual_attention_mask)

        loss = None

        if labels is not None:
            # scores = batch x selected position x visual_feature
            # scores = selected_positions.bmm(visual_features.transpose(1,2))
            # label = batch x selected_postion x needed position
            loss_fct = KLDivLoss(reduction="batchmean")
            log_softmax = LogSoftmax(dim=-1)
            scores = log_softmax(logits)
            labels = labels.contiguous()
            loss = loss_fct(scores, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "VisualBertForMultipleChoice",
    "VisualBertForPreTraining",
    "VisualBertForQuestionAnswering",
    "VisualBertForRegionToPhraseAlignment",
    "VisualBertForVisualReasoning",
    "VisualBertLayer",
    "VisualBertModel",
    "VisualBertPreTrainedModel",
]
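Every usage example in the docstrings above assumes a `get_visual_embeddings(image)` helper that this file never defines: it must return region features of shape `(visual_seq_length, visual_embedding_dim)`, typically pooled from an object detector. A minimal, hypothetical stand-in that keeps those examples runnable is sketched below; the helper name, region count, and random features are illustrative assumptions, not part of the library.

```python
import torch


def get_visual_embeddings(image, num_regions=36, visual_embedding_dim=2048):
    # Illustrative stand-in only: a real pipeline would run a detector backbone
    # (e.g. a Faster R-CNN style region extractor) over `image` and return its
    # pooled region features. Random values merely match the expected shape.
    return torch.randn(num_regions, visual_embedding_dim)


# Usage mirroring the docstring examples above:
# visual_embeds = get_visual_embeddings(image).unsqueeze(0)  # (1, 36, 2048)
```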
docs/transformers/build/lib/transformers/models/vit/__init__.py
ADDED
@@ -0,0 +1,32 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vit import *
    from .feature_extraction_vit import *
    from .image_processing_vit import *
    from .image_processing_vit_fast import *
    from .modeling_flax_vit import *
    from .modeling_tf_vit import *
    from .modeling_vit import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
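Because the module body above only installs a `_LazyModule`, importing `transformers.models.vit` stays cheap: the concrete submodules are imported the first time one of their attributes is looked up. A small illustrative check, assuming a standard transformers installation:

```python
import transformers.models.vit as vit

# Attribute access goes through the lazy module, which imports
# configuration_vit on demand and returns the real class.
config = vit.ViTConfig()
print(type(config).__name__)  # ViTConfig
```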
docs/transformers/build/lib/transformers/models/vit/configuration_vit.py
ADDED
@@ -0,0 +1,151 @@
# coding=utf-8
# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT model configuration"""

from collections import OrderedDict
from typing import Mapping

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class ViTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTModel`]. It is used to instantiate a ViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ViT
    [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        encoder_stride (`int`, *optional*, defaults to 16):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
        pooler_output_size (`int`, *optional*):
            Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
        pooler_act (`str`, *optional*, defaults to `"tanh"`):
            The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
            Pytorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
            supported for Tensorflow.

    Example:

    ```python
    >>> from transformers import ViTConfig, ViTModel

    >>> # Initializing a ViT vit-base-patch16-224 style configuration
    >>> configuration = ViTConfig()

    >>> # Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
    >>> model = ViTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vit"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        encoder_stride=16,
        pooler_output_size=None,
        pooler_act="tanh",
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.encoder_stride = encoder_stride
        self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
        self.pooler_act = pooler_act


class ViTOnnxConfig(OnnxConfig):
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        return 1e-4


__all__ = ["ViTConfig", "ViTOnnxConfig"]
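As a quick sanity check of the defaults documented above, the patch grid and the sequence length seen by the ViT encoder follow directly from `image_size` and `patch_size`; a short sketch:

```python
from transformers import ViTConfig

config = ViTConfig()  # image_size=224, patch_size=16, hidden_size=768 by default
num_patches = (config.image_size // config.patch_size) ** 2  # 14 * 14 = 196 patches
seq_length = num_patches + 1  # plus the [CLS] token -> 197 positions
print(num_patches, seq_length, config.hidden_size)
```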
docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py
ADDED
@@ -0,0 +1,218 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT checkpoints trained with the DINO method."""

import argparse
import json
from pathlib import Path

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("cls_token", "vit.embeddings.cls_token"),
            ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm + pooler
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys


# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
|
| 127 |
+
return im
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@torch.no_grad()
|
| 131 |
+
def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
|
| 132 |
+
"""
|
| 133 |
+
Copy/paste/tweak model's weights to our ViT structure.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
# define default ViT configuration
|
| 137 |
+
config = ViTConfig()
|
| 138 |
+
# patch_size
|
| 139 |
+
if model_name[-1] == "8":
|
| 140 |
+
config.patch_size = 8
|
| 141 |
+
# set labels if required
|
| 142 |
+
if not base_model:
|
| 143 |
+
config.num_labels = 1000
|
| 144 |
+
repo_id = "huggingface/label-files"
|
| 145 |
+
filename = "imagenet-1k-id2label.json"
|
| 146 |
+
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
|
| 147 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 148 |
+
config.id2label = id2label
|
| 149 |
+
config.label2id = {v: k for k, v in id2label.items()}
|
| 150 |
+
# size of the architecture
|
| 151 |
+
if model_name in ["dino_vits8", "dino_vits16"]:
|
| 152 |
+
config.hidden_size = 384
|
| 153 |
+
config.intermediate_size = 1536
|
| 154 |
+
config.num_hidden_layers = 12
|
| 155 |
+
config.num_attention_heads = 6
|
| 156 |
+
|
| 157 |
+
# load original model from torch hub
|
| 158 |
+
original_model = torch.hub.load("facebookresearch/dino:main", model_name)
|
| 159 |
+
original_model.eval()
|
| 160 |
+
|
| 161 |
+
# load state_dict of original model, remove and rename some keys
|
| 162 |
+
state_dict = original_model.state_dict()
|
| 163 |
+
if base_model:
|
| 164 |
+
remove_classification_head_(state_dict)
|
| 165 |
+
rename_keys = create_rename_keys(config, base_model=base_model)
|
| 166 |
+
for src, dest in rename_keys:
|
| 167 |
+
rename_key(state_dict, src, dest)
|
| 168 |
+
read_in_q_k_v(state_dict, config, base_model)
|
| 169 |
+
|
| 170 |
+
# load HuggingFace model
|
| 171 |
+
if base_model:
|
| 172 |
+
model = ViTModel(config, add_pooling_layer=False).eval()
|
| 173 |
+
else:
|
| 174 |
+
model = ViTForImageClassification(config).eval()
|
| 175 |
+
model.load_state_dict(state_dict)
|
| 176 |
+
|
| 177 |
+
# Check outputs on an image, prepared by ViTImageProcessor
|
| 178 |
+
image_processor = ViTImageProcessor()
|
| 179 |
+
encoding = image_processor(images=prepare_img(), return_tensors="pt")
|
| 180 |
+
pixel_values = encoding["pixel_values"]
|
| 181 |
+
outputs = model(pixel_values)
|
| 182 |
+
|
| 183 |
+
if base_model:
|
| 184 |
+
final_hidden_state_cls_token = original_model(pixel_values)
|
| 185 |
+
assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
|
| 186 |
+
else:
|
| 187 |
+
logits = original_model(pixel_values)
|
| 188 |
+
assert logits.shape == outputs.logits.shape
|
| 189 |
+
assert torch.allclose(logits, outputs.logits, atol=1e-3)
|
| 190 |
+
|
| 191 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 192 |
+
print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
|
| 193 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 194 |
+
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
| 195 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
parser = argparse.ArgumentParser()
|
| 200 |
+
# Required parameters
|
| 201 |
+
parser.add_argument(
|
| 202 |
+
"--model_name",
|
| 203 |
+
default="dino_vitb16",
|
| 204 |
+
type=str,
|
| 205 |
+
help="Name of the model trained with DINO you'd like to convert.",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument(
|
| 211 |
+
"--base_model",
|
| 212 |
+
action="store_true",
|
| 213 |
+
help="Whether to only convert the base model (no projection head weights).",
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
parser.set_defaults(base_model=True)
|
| 217 |
+
args = parser.parse_args()
|
| 218 |
+
convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
|
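For orientation, here is a minimal sketch of how this DINO conversion script might be run and how the converted checkpoint could be loaded back; the output directory name is illustrative and not part of the diff:

# Run the conversion (downloads the original weights via torch.hub):
#   python convert_dino_to_pytorch.py --model_name dino_vitb16 --pytorch_dump_folder_path ./dino_vitb16_hf

from transformers import ViTImageProcessor, ViTModel

# --base_model defaults to True, so the saved checkpoint is a bare ViTModel without a pooling layer.
model = ViTModel.from_pretrained("./dino_vitb16_hf", add_pooling_layer=False)
image_processor = ViTImageProcessor.from_pretrained("./dino_vitb16_hf")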
docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py
ADDED
|
@@ -0,0 +1,254 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert ViT and non-distilled DeiT checkpoints from the timm library."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
import timm
|
| 22 |
+
import torch
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from timm.data import ImageNetInfo, infer_imagenet_subset
|
| 25 |
+
|
| 26 |
+
from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
|
| 27 |
+
from transformers.utils import logging
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
logging.set_verbosity_info()
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# here we list all keys to be renamed (original name on the left, our name on the right)
|
| 35 |
+
def create_rename_keys(config, base_model=False):
|
| 36 |
+
rename_keys = []
|
| 37 |
+
for i in range(config.num_hidden_layers):
|
| 38 |
+
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
|
| 39 |
+
rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
|
| 40 |
+
rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
|
| 41 |
+
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
|
| 42 |
+
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
|
| 43 |
+
rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
|
| 44 |
+
rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
|
| 45 |
+
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
|
| 46 |
+
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
|
| 47 |
+
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
|
| 48 |
+
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
|
| 49 |
+
|
| 50 |
+
# projection layer + position embeddings
|
| 51 |
+
rename_keys.extend(
|
| 52 |
+
[
|
| 53 |
+
("cls_token", "vit.embeddings.cls_token"),
|
| 54 |
+
("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
|
| 55 |
+
("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
|
| 56 |
+
("pos_embed", "vit.embeddings.position_embeddings"),
|
| 57 |
+
]
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
if base_model:
|
| 61 |
+
# layernorm
|
| 62 |
+
rename_keys.extend(
|
| 63 |
+
[
|
| 64 |
+
("norm.weight", "layernorm.weight"),
|
| 65 |
+
("norm.bias", "layernorm.bias"),
|
| 66 |
+
]
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# if just the base model, we should remove "vit" from all keys that start with "vit"
|
| 70 |
+
rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
|
| 71 |
+
else:
|
| 72 |
+
# layernorm + classification head
|
| 73 |
+
rename_keys.extend(
|
| 74 |
+
[
|
| 75 |
+
("norm.weight", "vit.layernorm.weight"),
|
| 76 |
+
("norm.bias", "vit.layernorm.bias"),
|
| 77 |
+
("head.weight", "classifier.weight"),
|
| 78 |
+
("head.bias", "classifier.bias"),
|
| 79 |
+
]
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return rename_keys
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# we split up the matrix of each encoder layer into queries, keys and values
|
| 86 |
+
def read_in_q_k_v(state_dict, config, base_model=False):
|
| 87 |
+
for i in range(config.num_hidden_layers):
|
| 88 |
+
if base_model:
|
| 89 |
+
prefix = ""
|
| 90 |
+
else:
|
| 91 |
+
prefix = "vit."
|
| 92 |
+
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
|
| 93 |
+
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
|
| 94 |
+
in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
|
| 95 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 96 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
| 97 |
+
: config.hidden_size, :
|
| 98 |
+
]
|
| 99 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
|
| 100 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
| 101 |
+
config.hidden_size : config.hidden_size * 2, :
|
| 102 |
+
]
|
| 103 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
|
| 104 |
+
config.hidden_size : config.hidden_size * 2
|
| 105 |
+
]
|
| 106 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
| 107 |
+
-config.hidden_size :, :
|
| 108 |
+
]
|
| 109 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def remove_classification_head_(state_dict):
|
| 113 |
+
ignore_keys = ["head.weight", "head.bias"]
|
| 114 |
+
for k in ignore_keys:
|
| 115 |
+
state_dict.pop(k, None)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def rename_key(dct, old, new):
|
| 119 |
+
val = dct.pop(old)
|
| 120 |
+
dct[new] = val
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# We will verify our results on an image of cute cats
|
| 124 |
+
def prepare_img():
|
| 125 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 126 |
+
im = Image.open(requests.get(url, stream=True).raw)
|
| 127 |
+
return im
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@torch.no_grad()
|
| 131 |
+
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
|
| 132 |
+
"""
|
| 133 |
+
Copy/paste/tweak model's weights to our ViT structure.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
# define default ViT configuration
|
| 137 |
+
config = ViTConfig()
|
| 138 |
+
base_model = False
|
| 139 |
+
|
| 140 |
+
# load original model from timm
|
| 141 |
+
timm_model = timm.create_model(vit_name, pretrained=True)
|
| 142 |
+
timm_model.eval()
|
| 143 |
+
|
| 144 |
+
# detect unsupported ViT models in transformers
|
| 145 |
+
# fc_norm is present
|
| 146 |
+
if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
|
| 147 |
+
raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")
|
| 148 |
+
|
| 149 |
+
# use of global average pooling in combination with (or without) the class token
|
| 150 |
+
if getattr(timm_model, "global_pool", None) == "avg":
|
| 151 |
+
raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")
|
| 152 |
+
|
| 153 |
+
# CLIP style vit with norm_pre layer present
|
| 154 |
+
if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
|
| 155 |
+
raise ValueError(
|
| 156 |
+
f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# SigLIP style vit with attn_pool layer present
|
| 160 |
+
if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
|
| 161 |
+
raise ValueError(
|
| 162 |
+
f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# use of layer scale in ViT model blocks
|
| 166 |
+
if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
|
| 167 |
+
getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
|
| 168 |
+
):
|
| 169 |
+
raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")
|
| 170 |
+
|
| 171 |
+
# Hybrid ResNet-ViTs
|
| 172 |
+
if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
|
| 173 |
+
raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")
|
| 174 |
+
|
| 175 |
+
# get patch size and image size from the patch embedding submodule
|
| 176 |
+
config.patch_size = timm_model.patch_embed.patch_size[0]
|
| 177 |
+
config.image_size = timm_model.patch_embed.img_size[0]
|
| 178 |
+
|
| 179 |
+
# retrieve architecture-specific parameters from the timm model
|
| 180 |
+
config.hidden_size = timm_model.embed_dim
|
| 181 |
+
config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
|
| 182 |
+
config.num_hidden_layers = len(timm_model.blocks)
|
| 183 |
+
config.num_attention_heads = timm_model.blocks[0].attn.num_heads
|
| 184 |
+
|
| 185 |
+
# check whether the model has a classification head or not
|
| 186 |
+
if timm_model.num_classes != 0:
|
| 187 |
+
config.num_labels = timm_model.num_classes
|
| 188 |
+
# infer ImageNet subset from timm model
|
| 189 |
+
imagenet_subset = infer_imagenet_subset(timm_model)
|
| 190 |
+
dataset_info = ImageNetInfo(imagenet_subset)
|
| 191 |
+
config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
|
| 192 |
+
config.label2id = {v: k for k, v in config.id2label.items()}
|
| 193 |
+
else:
|
| 194 |
+
print(f"{vit_name} is going to be converted as a feature extractor only.")
|
| 195 |
+
base_model = True
|
| 196 |
+
|
| 197 |
+
# load state_dict of original model
|
| 198 |
+
state_dict = timm_model.state_dict()
|
| 199 |
+
|
| 200 |
+
# remove and rename some keys in the state dict
|
| 201 |
+
if base_model:
|
| 202 |
+
remove_classification_head_(state_dict)
|
| 203 |
+
rename_keys = create_rename_keys(config, base_model)
|
| 204 |
+
for src, dest in rename_keys:
|
| 205 |
+
rename_key(state_dict, src, dest)
|
| 206 |
+
read_in_q_k_v(state_dict, config, base_model)
|
| 207 |
+
|
| 208 |
+
# load HuggingFace model
|
| 209 |
+
if base_model:
|
| 210 |
+
model = ViTModel(config, add_pooling_layer=False).eval()
|
| 211 |
+
else:
|
| 212 |
+
model = ViTForImageClassification(config).eval()
|
| 213 |
+
model.load_state_dict(state_dict)
|
| 214 |
+
|
| 215 |
+
# Check outputs on an image, prepared by ViTImageProcessor/DeiTImageProcessor
|
| 216 |
+
if "deit" in vit_name:
|
| 217 |
+
image_processor = DeiTImageProcessor(size=config.image_size)
|
| 218 |
+
else:
|
| 219 |
+
image_processor = ViTImageProcessor(size=config.image_size)
|
| 220 |
+
encoding = image_processor(images=prepare_img(), return_tensors="pt")
|
| 221 |
+
pixel_values = encoding["pixel_values"]
|
| 222 |
+
outputs = model(pixel_values)
|
| 223 |
+
|
| 224 |
+
if base_model:
|
| 225 |
+
timm_pooled_output = timm_model.forward_features(pixel_values)
|
| 226 |
+
assert timm_pooled_output.shape == outputs.last_hidden_state.shape
|
| 227 |
+
assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
|
| 228 |
+
else:
|
| 229 |
+
timm_logits = timm_model(pixel_values)
|
| 230 |
+
assert timm_logits.shape == outputs.logits.shape
|
| 231 |
+
assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
|
| 232 |
+
|
| 233 |
+
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
|
| 234 |
+
print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
|
| 235 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 236 |
+
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
| 237 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
if __name__ == "__main__":
|
| 241 |
+
parser = argparse.ArgumentParser()
|
| 242 |
+
# Required parameters
|
| 243 |
+
parser.add_argument(
|
| 244 |
+
"--vit_name",
|
| 245 |
+
default="vit_base_patch16_224",
|
| 246 |
+
type=str,
|
| 247 |
+
help="Name of the ViT timm model you'd like to convert.",
|
| 248 |
+
)
|
| 249 |
+
parser.add_argument(
|
| 250 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
args = parser.parse_args()
|
| 254 |
+
convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
|
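Both conversion scripts above rely on read_in_q_k_v to unfuse timm's packed qkv projection. A minimal sketch of the slicing it performs, using the ViT-Base hidden size purely as an illustrative assumption:

import torch

hidden_size = 768  # ViT-Base default, assumed here only for illustration
in_proj_weight = torch.randn(3 * hidden_size, hidden_size)  # timm packs q, k, v along dim 0

query_w = in_proj_weight[:hidden_size, :]
key_w = in_proj_weight[hidden_size : hidden_size * 2, :]
value_w = in_proj_weight[-hidden_size:, :]

# The three slices reassemble exactly into the fused matrix, in q, k, v order.
assert torch.equal(torch.cat([query_w, key_w, value_w], dim=0), in_proj_weight)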
docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py
ADDED
|
@@ -0,0 +1,38 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for ViT."""

import warnings

from ...utils import logging
from ...utils.import_utils import requires
from .image_processing_vit import ViTImageProcessor


logger = logging.get_logger(__name__)


@requires(backends=("vision",))
class ViTFeatureExtractor(ViTImageProcessor):
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
            " use ViTImageProcessor instead.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)


__all__ = ["ViTFeatureExtractor"]
|
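The shim above keeps the old ViTFeatureExtractor name importable while delegating everything to ViTImageProcessor. A small sketch of the behaviour one would expect from it:

import warnings

from transformers import ViTFeatureExtractor

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature_extractor = ViTFeatureExtractor()  # same defaults as ViTImageProcessor

# Instantiation emits the FutureWarning defined above, but the object is a full image processor.
assert any(issubclass(w.category, FutureWarning) for w in caught)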
docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py
ADDED
|
@@ -0,0 +1,288 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for ViT."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 22 |
+
from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
| 23 |
+
from ...image_utils import (
|
| 24 |
+
IMAGENET_STANDARD_MEAN,
|
| 25 |
+
IMAGENET_STANDARD_STD,
|
| 26 |
+
ChannelDimension,
|
| 27 |
+
ImageInput,
|
| 28 |
+
PILImageResampling,
|
| 29 |
+
infer_channel_dimension_format,
|
| 30 |
+
is_scaled_image,
|
| 31 |
+
make_list_of_images,
|
| 32 |
+
to_numpy_array,
|
| 33 |
+
valid_images,
|
| 34 |
+
validate_preprocess_arguments,
|
| 35 |
+
)
|
| 36 |
+
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
|
| 37 |
+
from ...utils.import_utils import requires
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@requires(backends=("vision",))
|
| 44 |
+
class ViTImageProcessor(BaseImageProcessor):
|
| 45 |
+
r"""
|
| 46 |
+
Constructs a ViT image processor.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 50 |
+
Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
|
| 51 |
+
size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
|
| 52 |
+
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 53 |
+
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
| 54 |
+
method.
|
| 55 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 56 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 57 |
+
`preprocess` method.
|
| 58 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 59 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 60 |
+
parameter in the `preprocess` method.
|
| 61 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
| 62 |
+
Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
|
| 63 |
+
`preprocess` method.
|
| 64 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 65 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 66 |
+
method.
|
| 67 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
| 68 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 69 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 70 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
| 71 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 72 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 73 |
+
do_convert_rgb (`bool`, *optional*):
|
| 74 |
+
Whether to convert the image to RGB.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
model_input_names = ["pixel_values"]
|
| 78 |
+
|
| 79 |
+
def __init__(
|
| 80 |
+
self,
|
| 81 |
+
do_resize: bool = True,
|
| 82 |
+
size: Optional[Dict[str, int]] = None,
|
| 83 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 84 |
+
do_rescale: bool = True,
|
| 85 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
| 86 |
+
do_normalize: bool = True,
|
| 87 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 88 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 89 |
+
do_convert_rgb: Optional[bool] = None,
|
| 90 |
+
**kwargs,
|
| 91 |
+
) -> None:
|
| 92 |
+
super().__init__(**kwargs)
|
| 93 |
+
size = size if size is not None else {"height": 224, "width": 224}
|
| 94 |
+
size = get_size_dict(size)
|
| 95 |
+
self.do_resize = do_resize
|
| 96 |
+
self.do_rescale = do_rescale
|
| 97 |
+
self.do_normalize = do_normalize
|
| 98 |
+
self.size = size
|
| 99 |
+
self.resample = resample
|
| 100 |
+
self.rescale_factor = rescale_factor
|
| 101 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 102 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 103 |
+
self.do_convert_rgb = do_convert_rgb
|
| 104 |
+
|
| 105 |
+
def resize(
|
| 106 |
+
self,
|
| 107 |
+
image: np.ndarray,
|
| 108 |
+
size: Dict[str, int],
|
| 109 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 110 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 111 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 112 |
+
**kwargs,
|
| 113 |
+
) -> np.ndarray:
|
| 114 |
+
"""
|
| 115 |
+
Resize an image to `(size["height"], size["width"])`.
|
| 116 |
+
|
| 117 |
+
Args:
|
| 118 |
+
image (`np.ndarray`):
|
| 119 |
+
Image to resize.
|
| 120 |
+
size (`Dict[str, int]`):
|
| 121 |
+
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
| 122 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 123 |
+
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
|
| 124 |
+
data_format (`ChannelDimension` or `str`, *optional*):
|
| 125 |
+
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
| 126 |
+
image is used. Can be one of:
|
| 127 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 128 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 129 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 130 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 131 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 132 |
+
from the input image. Can be one of:
|
| 133 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 134 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 135 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
`np.ndarray`: The resized image.
|
| 139 |
+
"""
|
| 140 |
+
size = get_size_dict(size)
|
| 141 |
+
if "height" not in size or "width" not in size:
|
| 142 |
+
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
| 143 |
+
output_size = (size["height"], size["width"])
|
| 144 |
+
return resize(
|
| 145 |
+
image,
|
| 146 |
+
size=output_size,
|
| 147 |
+
resample=resample,
|
| 148 |
+
data_format=data_format,
|
| 149 |
+
input_data_format=input_data_format,
|
| 150 |
+
**kwargs,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
@filter_out_non_signature_kwargs()
|
| 154 |
+
def preprocess(
|
| 155 |
+
self,
|
| 156 |
+
images: ImageInput,
|
| 157 |
+
do_resize: Optional[bool] = None,
|
| 158 |
+
size: Dict[str, int] = None,
|
| 159 |
+
resample: PILImageResampling = None,
|
| 160 |
+
do_rescale: Optional[bool] = None,
|
| 161 |
+
rescale_factor: Optional[float] = None,
|
| 162 |
+
do_normalize: Optional[bool] = None,
|
| 163 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 164 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 165 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 166 |
+
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
|
| 167 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 168 |
+
do_convert_rgb: Optional[bool] = None,
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Preprocess an image or batch of images.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
images (`ImageInput`):
|
| 175 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
| 176 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
| 177 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 178 |
+
Whether to resize the image.
|
| 179 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 180 |
+
Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
|
| 181 |
+
resizing.
|
| 182 |
+
resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
|
| 183 |
+
`PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
|
| 184 |
+
an effect if `do_resize` is set to `True`.
|
| 185 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 186 |
+
Whether to rescale the image values between [0 - 1].
|
| 187 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 188 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 189 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 190 |
+
Whether to normalize the image.
|
| 191 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 192 |
+
Image mean to use if `do_normalize` is set to `True`.
|
| 193 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 194 |
+
Image standard deviation to use if `do_normalize` is set to `True`.
|
| 195 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 196 |
+
The type of tensors to return. Can be one of:
|
| 197 |
+
- Unset: Return a list of `np.ndarray`.
|
| 198 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 199 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 200 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 201 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 202 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 203 |
+
The channel dimension format for the output image. Can be one of:
|
| 204 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 205 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 206 |
+
- Unset: Use the channel dimension format of the input image.
|
| 207 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 208 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 209 |
+
from the input image. Can be one of:
|
| 210 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 211 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 212 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 213 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
| 214 |
+
Whether to convert the image to RGB.
|
| 215 |
+
"""
|
| 216 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 217 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 218 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 219 |
+
resample = resample if resample is not None else self.resample
|
| 220 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 221 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 222 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 223 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
| 224 |
+
|
| 225 |
+
size = size if size is not None else self.size
|
| 226 |
+
size_dict = get_size_dict(size)
|
| 227 |
+
|
| 228 |
+
images = make_list_of_images(images)
|
| 229 |
+
|
| 230 |
+
if not valid_images(images):
|
| 231 |
+
raise ValueError(
|
| 232 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 233 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 234 |
+
)
|
| 235 |
+
validate_preprocess_arguments(
|
| 236 |
+
do_rescale=do_rescale,
|
| 237 |
+
rescale_factor=rescale_factor,
|
| 238 |
+
do_normalize=do_normalize,
|
| 239 |
+
image_mean=image_mean,
|
| 240 |
+
image_std=image_std,
|
| 241 |
+
do_resize=do_resize,
|
| 242 |
+
size=size,
|
| 243 |
+
resample=resample,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if do_convert_rgb:
|
| 247 |
+
images = [convert_to_rgb(image) for image in images]
|
| 248 |
+
|
| 249 |
+
# All transformations expect numpy arrays.
|
| 250 |
+
images = [to_numpy_array(image) for image in images]
|
| 251 |
+
|
| 252 |
+
if do_rescale and is_scaled_image(images[0]):
|
| 253 |
+
logger.warning_once(
|
| 254 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 255 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
if input_data_format is None:
|
| 259 |
+
# We assume that all images have the same channel dimension format.
|
| 260 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
| 261 |
+
|
| 262 |
+
if do_resize:
|
| 263 |
+
images = [
|
| 264 |
+
self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
|
| 265 |
+
for image in images
|
| 266 |
+
]
|
| 267 |
+
|
| 268 |
+
if do_rescale:
|
| 269 |
+
images = [
|
| 270 |
+
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
| 271 |
+
for image in images
|
| 272 |
+
]
|
| 273 |
+
|
| 274 |
+
if do_normalize:
|
| 275 |
+
images = [
|
| 276 |
+
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 277 |
+
for image in images
|
| 278 |
+
]
|
| 279 |
+
|
| 280 |
+
images = [
|
| 281 |
+
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
data = {"pixel_values": images}
|
| 285 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
__all__ = ["ViTImageProcessor"]
|
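As a usage sketch for the image processor defined above (the dummy NumPy image stands in for a real PIL image and is only illustrative):

import numpy as np

from transformers import ViTImageProcessor

image_processor = ViTImageProcessor()  # 224x224 resize, 1/255 rescale, ImageNet-standard normalization

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # dummy HWC uint8 image
inputs = image_processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected: torch.Size([1, 3, 224, 224])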
docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py
ADDED
|
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ViT."""

from ...image_processing_utils_fast import (
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
    BaseImageProcessorFast,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    PILImageResampling,
)
from ...utils import (
    add_start_docstrings,
)


@add_start_docstrings(
    "Constructs a fast ViT image processor.",
    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
)
class ViTImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.BILINEAR
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"height": 224, "width": 224}
    do_resize = True
    do_rescale = True
    do_normalize = True


__all__ = ["ViTImageProcessorFast"]
|
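The fast processor shares the slow processor's defaults and only overrides class attributes; a quick sketch comparing the two (the fast variant assumes the torchvision backend is available):

from transformers import ViTImageProcessor, ViTImageProcessorFast

slow = ViTImageProcessor()
fast = ViTImageProcessorFast()

# Both resize to 224x224, rescale by 1/255 and normalize with the ImageNet-standard statistics.
print(slow.size, fast.size)  # {'height': 224, 'width': 224} for both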
docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py
ADDED
|
@@ -0,0 +1,677 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import flax.linen as nn
|
| 19 |
+
import jax
|
| 20 |
+
import jax.numpy as jnp
|
| 21 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
| 22 |
+
from flax.linen.attention import dot_product_attention_weights
|
| 23 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
| 24 |
+
|
| 25 |
+
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
|
| 26 |
+
from ...modeling_flax_utils import (
|
| 27 |
+
ACT2FN,
|
| 28 |
+
FlaxPreTrainedModel,
|
| 29 |
+
append_replace_return_docstrings,
|
| 30 |
+
overwrite_call_docstring,
|
| 31 |
+
)
|
| 32 |
+
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
| 33 |
+
from .configuration_vit import ViTConfig
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
VIT_START_DOCSTRING = r"""
|
| 37 |
+
|
| 38 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 39 |
+
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
|
| 40 |
+
|
| 41 |
+
This model is also a
|
| 42 |
+
[flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
|
| 43 |
+
a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
|
| 44 |
+
behavior.
|
| 45 |
+
|
| 46 |
+
Finally, this model supports inherent JAX features such as:
|
| 47 |
+
|
| 48 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
| 49 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
| 50 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
| 51 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
| 52 |
+
|
| 53 |
+
Parameters:
|
| 54 |
+
config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
|
| 55 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 56 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 57 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
| 58 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
| 59 |
+
`jax.numpy.bfloat16` (on TPUs).
|
| 60 |
+
|
| 61 |
+
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
| 62 |
+
specified all the computation will be performed with the given `dtype`.
|
| 63 |
+
|
| 64 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
| 65 |
+
parameters.**
|
| 66 |
+
|
| 67 |
+
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
| 68 |
+
[`~FlaxPreTrainedModel.to_bf16`].
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
VIT_INPUTS_DOCSTRING = r"""
|
| 72 |
+
Args:
|
| 73 |
+
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
|
| 74 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
|
| 75 |
+
for details.
|
| 76 |
+
|
| 77 |
+
output_attentions (`bool`, *optional*):
|
| 78 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 79 |
+
tensors for more detail.
|
| 80 |
+
output_hidden_states (`bool`, *optional*):
|
| 81 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 82 |
+
more detail.
|
| 83 |
+
return_dict (`bool`, *optional*):
|
| 84 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class FlaxViTPatchEmbeddings(nn.Module):
|
| 89 |
+
config: ViTConfig
|
| 90 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 91 |
+
|
| 92 |
+
def setup(self):
|
| 93 |
+
image_size = self.config.image_size
|
| 94 |
+
patch_size = self.config.patch_size
|
| 95 |
+
num_patches = (image_size // patch_size) * (image_size // patch_size)
|
| 96 |
+
self.num_patches = num_patches
|
| 97 |
+
self.num_channels = self.config.num_channels
|
| 98 |
+
self.projection = nn.Conv(
|
| 99 |
+
self.config.hidden_size,
|
| 100 |
+
kernel_size=(patch_size, patch_size),
|
| 101 |
+
strides=(patch_size, patch_size),
|
| 102 |
+
padding="VALID",
|
| 103 |
+
dtype=self.dtype,
|
| 104 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 105 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 106 |
+
),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
def __call__(self, pixel_values):
|
| 110 |
+
num_channels = pixel_values.shape[-1]
|
| 111 |
+
if num_channels != self.num_channels:
|
| 112 |
+
raise ValueError(
|
| 113 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 114 |
+
)
|
| 115 |
+
embeddings = self.projection(pixel_values)
|
| 116 |
+
batch_size, _, _, channels = embeddings.shape
|
| 117 |
+
return jnp.reshape(embeddings, (batch_size, -1, channels))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class FlaxViTEmbeddings(nn.Module):
|
| 121 |
+
"""Construct the CLS token, position and patch embeddings."""
|
| 122 |
+
|
| 123 |
+
config: ViTConfig
|
| 124 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 125 |
+
|
| 126 |
+
def setup(self):
|
| 127 |
+
self.cls_token = self.param(
|
| 128 |
+
"cls_token",
|
| 129 |
+
jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
|
| 130 |
+
(1, 1, self.config.hidden_size),
|
| 131 |
+
)
|
| 132 |
+
self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
|
| 133 |
+
num_patches = self.patch_embeddings.num_patches
|
| 134 |
+
self.position_embeddings = self.param(
|
| 135 |
+
"position_embeddings",
|
| 136 |
+
jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
|
| 137 |
+
(1, num_patches + 1, self.config.hidden_size),
|
| 138 |
+
)
|
| 139 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 140 |
+
|
| 141 |
+
def __call__(self, pixel_values, deterministic=True):
|
| 142 |
+
batch_size = pixel_values.shape[0]
|
| 143 |
+
|
| 144 |
+
embeddings = self.patch_embeddings(pixel_values)
|
| 145 |
+
|
| 146 |
+
cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
|
| 147 |
+
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
|
| 148 |
+
embeddings = embeddings + self.position_embeddings
|
| 149 |
+
embeddings = self.dropout(embeddings, deterministic=deterministic)
|
| 150 |
+
return embeddings
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class FlaxViTSelfAttention(nn.Module):
|
| 154 |
+
config: ViTConfig
|
| 155 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 156 |
+
|
| 157 |
+
def setup(self):
|
| 158 |
+
if self.config.hidden_size % self.config.num_attention_heads != 0:
|
| 159 |
+
raise ValueError(
|
| 160 |
+
f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
|
| 161 |
+
f" {self.config.num_attention_heads}"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self.query = nn.Dense(
|
| 165 |
+
self.config.hidden_size,
|
| 166 |
+
dtype=self.dtype,
|
| 167 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 168 |
+
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
|
| 169 |
+
),
|
| 170 |
+
use_bias=self.config.qkv_bias,
|
| 171 |
+
)
|
| 172 |
+
self.key = nn.Dense(
|
| 173 |
+
self.config.hidden_size,
|
| 174 |
+
dtype=self.dtype,
|
| 175 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 176 |
+
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
|
| 177 |
+
),
|
| 178 |
+
use_bias=self.config.qkv_bias,
|
| 179 |
+
)
|
| 180 |
+
self.value = nn.Dense(
|
| 181 |
+
self.config.hidden_size,
|
| 182 |
+
dtype=self.dtype,
|
| 183 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 184 |
+
self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
|
| 185 |
+
),
|
| 186 |
+
use_bias=self.config.qkv_bias,
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
|
| 190 |
+
head_dim = self.config.hidden_size // self.config.num_attention_heads
|
| 191 |
+
|
| 192 |
+
query_states = self.query(hidden_states).reshape(
|
| 193 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 194 |
+
)
|
| 195 |
+
value_states = self.value(hidden_states).reshape(
|
| 196 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 197 |
+
)
|
| 198 |
+
key_states = self.key(hidden_states).reshape(
|
| 199 |
+
hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
dropout_rng = None
|
| 203 |
+
if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
|
| 204 |
+
dropout_rng = self.make_rng("dropout")
|
| 205 |
+
|
| 206 |
+
attn_weights = dot_product_attention_weights(
|
| 207 |
+
query_states,
|
| 208 |
+
key_states,
|
| 209 |
+
dropout_rng=dropout_rng,
|
| 210 |
+
dropout_rate=self.config.attention_probs_dropout_prob,
|
| 211 |
+
broadcast_dropout=True,
|
| 212 |
+
deterministic=deterministic,
|
| 213 |
+
dtype=self.dtype,
|
| 214 |
+
precision=None,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
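# attn_weights has shape (..., num_heads, q_len, k_len) and value_states has shape
# (..., k_len, num_heads, head_dim), so the einsum above yields
# (..., q_len, num_heads, head_dim); the reshape below merges the per-head dimensions
# back into hidden_size.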
| 218 |
+
attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
|
| 219 |
+
|
| 220 |
+
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
|
| 221 |
+
return outputs
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class FlaxViTSelfOutput(nn.Module):
|
| 225 |
+
config: ViTConfig
|
| 226 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 227 |
+
|
| 228 |
+
def setup(self):
|
| 229 |
+
self.dense = nn.Dense(
|
| 230 |
+
self.config.hidden_size,
|
| 231 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 232 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 233 |
+
),
|
| 234 |
+
dtype=self.dtype,
|
| 235 |
+
)
|
| 236 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 237 |
+
|
| 238 |
+
def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
|
| 239 |
+
hidden_states = self.dense(hidden_states)
|
| 240 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
| 241 |
+
return hidden_states
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class FlaxViTAttention(nn.Module):
|
| 245 |
+
config: ViTConfig
|
| 246 |
+
dtype: jnp.dtype = jnp.float32
|
| 247 |
+
|
| 248 |
+
def setup(self):
|
| 249 |
+
self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
|
| 250 |
+
self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
|
| 251 |
+
|
| 252 |
+
def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
|
| 253 |
+
attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
|
| 254 |
+
attn_output = attn_outputs[0]
|
| 255 |
+
hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
|
| 256 |
+
|
| 257 |
+
outputs = (hidden_states,)
|
| 258 |
+
|
| 259 |
+
if output_attentions:
|
| 260 |
+
outputs += (attn_outputs[1],)
|
| 261 |
+
|
| 262 |
+
return outputs
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class FlaxViTIntermediate(nn.Module):
|
| 266 |
+
config: ViTConfig
|
| 267 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 268 |
+
|
| 269 |
+
def setup(self):
|
| 270 |
+
self.dense = nn.Dense(
|
| 271 |
+
self.config.intermediate_size,
|
| 272 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 273 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 274 |
+
),
|
| 275 |
+
dtype=self.dtype,
|
| 276 |
+
)
|
| 277 |
+
self.activation = ACT2FN[self.config.hidden_act]
|
| 278 |
+
|
| 279 |
+
def __call__(self, hidden_states):
|
| 280 |
+
hidden_states = self.dense(hidden_states)
|
| 281 |
+
hidden_states = self.activation(hidden_states)
|
| 282 |
+
return hidden_states
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class FlaxViTOutput(nn.Module):
|
| 286 |
+
config: ViTConfig
|
| 287 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 288 |
+
|
| 289 |
+
def setup(self):
|
| 290 |
+
self.dense = nn.Dense(
|
| 291 |
+
self.config.hidden_size,
|
| 292 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 293 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 294 |
+
),
|
| 295 |
+
dtype=self.dtype,
|
| 296 |
+
)
|
| 297 |
+
self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
|
| 298 |
+
|
| 299 |
+
def __call__(self, hidden_states, attention_output, deterministic: bool = True):
|
| 300 |
+
hidden_states = self.dense(hidden_states)
|
| 301 |
+
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
|
| 302 |
+
hidden_states = hidden_states + attention_output
|
| 303 |
+
return hidden_states
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class FlaxViTLayer(nn.Module):
|
| 307 |
+
config: ViTConfig
|
| 308 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 309 |
+
|
| 310 |
+
def setup(self):
|
| 311 |
+
self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
|
| 312 |
+
self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
|
| 313 |
+
self.output = FlaxViTOutput(self.config, dtype=self.dtype)
|
| 314 |
+
self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
| 315 |
+
self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
| 316 |
+
|
| 317 |
+
def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
|
| 318 |
+
attention_outputs = self.attention(
|
| 319 |
+
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
|
| 320 |
+
deterministic=deterministic,
|
| 321 |
+
output_attentions=output_attentions,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
attention_output = attention_outputs[0]
|
| 325 |
+
|
| 326 |
+
# first residual connection
|
| 327 |
+
attention_output = attention_output + hidden_states
|
| 328 |
+
|
| 329 |
+
# in ViT, layernorm is also applied after self-attention
|
| 330 |
+
layer_output = self.layernorm_after(attention_output)
|
| 331 |
+
|
| 332 |
+
hidden_states = self.intermediate(layer_output)
|
| 333 |
+
hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
|
| 334 |
+
|
| 335 |
+
outputs = (hidden_states,)
|
| 336 |
+
|
| 337 |
+
if output_attentions:
|
| 338 |
+
outputs += (attention_outputs[1],)
|
| 339 |
+
return outputs
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class FlaxViTLayerCollection(nn.Module):
|
| 343 |
+
config: ViTConfig
|
| 344 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 345 |
+
|
| 346 |
+
def setup(self):
|
| 347 |
+
self.layers = [
|
| 348 |
+
FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
def __call__(
|
| 352 |
+
self,
|
| 353 |
+
hidden_states,
|
| 354 |
+
deterministic: bool = True,
|
| 355 |
+
output_attentions: bool = False,
|
| 356 |
+
output_hidden_states: bool = False,
|
| 357 |
+
return_dict: bool = True,
|
| 358 |
+
):
|
| 359 |
+
all_attentions = () if output_attentions else None
|
| 360 |
+
all_hidden_states = () if output_hidden_states else None
|
| 361 |
+
|
| 362 |
+
for i, layer in enumerate(self.layers):
|
| 363 |
+
if output_hidden_states:
|
| 364 |
+
all_hidden_states += (hidden_states,)
|
| 365 |
+
|
| 366 |
+
layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
|
| 367 |
+
|
| 368 |
+
hidden_states = layer_outputs[0]
|
| 369 |
+
|
| 370 |
+
if output_attentions:
|
| 371 |
+
all_attentions += (layer_outputs[1],)
|
| 372 |
+
|
| 373 |
+
if output_hidden_states:
|
| 374 |
+
all_hidden_states += (hidden_states,)
|
| 375 |
+
|
| 376 |
+
outputs = (hidden_states,)
|
| 377 |
+
if not return_dict:
|
| 378 |
+
return tuple(v for v in outputs if v is not None)
|
| 379 |
+
|
| 380 |
+
return FlaxBaseModelOutput(
|
| 381 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class FlaxViTEncoder(nn.Module):
|
| 386 |
+
config: ViTConfig
|
| 387 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 388 |
+
|
| 389 |
+
def setup(self):
|
| 390 |
+
self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)
|
| 391 |
+
|
| 392 |
+
def __call__(
|
| 393 |
+
self,
|
| 394 |
+
hidden_states,
|
| 395 |
+
deterministic: bool = True,
|
| 396 |
+
output_attentions: bool = False,
|
| 397 |
+
output_hidden_states: bool = False,
|
| 398 |
+
return_dict: bool = True,
|
| 399 |
+
):
|
| 400 |
+
return self.layer(
|
| 401 |
+
hidden_states,
|
| 402 |
+
deterministic=deterministic,
|
| 403 |
+
output_attentions=output_attentions,
|
| 404 |
+
output_hidden_states=output_hidden_states,
|
| 405 |
+
return_dict=return_dict,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class FlaxViTPooler(nn.Module):
|
| 410 |
+
config: ViTConfig
|
| 411 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 412 |
+
|
| 413 |
+
def setup(self):
|
| 414 |
+
self.dense = nn.Dense(
|
| 415 |
+
self.config.pooler_output_size,
|
| 416 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 417 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 418 |
+
),
|
| 419 |
+
dtype=self.dtype,
|
| 420 |
+
)
|
| 421 |
+
self.activation = ACT2FN[self.config.pooler_act]
|
| 422 |
+
|
| 423 |
+
def __call__(self, hidden_states):
|
| 424 |
+
cls_hidden_state = hidden_states[:, 0]
|
| 425 |
+
cls_hidden_state = self.dense(cls_hidden_state)
|
| 426 |
+
return self.activation(cls_hidden_state)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
|
| 430 |
+
"""
|
| 431 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 432 |
+
models.
|
| 433 |
+
"""
|
| 434 |
+
|
| 435 |
+
config_class = ViTConfig
|
| 436 |
+
base_model_prefix = "vit"
|
| 437 |
+
main_input_name = "pixel_values"
|
| 438 |
+
module_class: nn.Module = None
|
| 439 |
+
|
| 440 |
+
def __init__(
|
| 441 |
+
self,
|
| 442 |
+
config: ViTConfig,
|
| 443 |
+
input_shape=None,
|
| 444 |
+
seed: int = 0,
|
| 445 |
+
dtype: jnp.dtype = jnp.float32,
|
| 446 |
+
_do_init: bool = True,
|
| 447 |
+
**kwargs,
|
| 448 |
+
):
|
| 449 |
+
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
| 450 |
+
if input_shape is None:
|
| 451 |
+
input_shape = (1, config.image_size, config.image_size, config.num_channels)
|
| 452 |
+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
| 453 |
+
|
| 454 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
| 455 |
+
# init input tensors
|
| 456 |
+
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
|
| 457 |
+
|
| 458 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
| 459 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
| 460 |
+
|
| 461 |
+
random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
|
| 462 |
+
|
| 463 |
+
if params is not None:
|
| 464 |
+
random_params = flatten_dict(unfreeze(random_params))
|
| 465 |
+
params = flatten_dict(unfreeze(params))
|
| 466 |
+
for missing_key in self._missing_keys:
|
| 467 |
+
params[missing_key] = random_params[missing_key]
|
| 468 |
+
self._missing_keys = set()
|
| 469 |
+
return freeze(unflatten_dict(params))
|
| 470 |
+
else:
|
| 471 |
+
return random_params
|
| 472 |
+
|
| 473 |
+
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
| 474 |
+
def __call__(
|
| 475 |
+
self,
|
| 476 |
+
pixel_values,
|
| 477 |
+
params: dict = None,
|
| 478 |
+
dropout_rng: jax.random.PRNGKey = None,
|
| 479 |
+
train: bool = False,
|
| 480 |
+
output_attentions: Optional[bool] = None,
|
| 481 |
+
output_hidden_states: Optional[bool] = None,
|
| 482 |
+
return_dict: Optional[bool] = None,
|
| 483 |
+
):
|
| 484 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 485 |
+
output_hidden_states = (
|
| 486 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 487 |
+
)
|
| 488 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
| 489 |
+
|
| 490 |
+
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
|
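# pixel_values arrive in PyTorch-style NCHW; the transpose above converts them to the
# channels-last NHWC layout expected by the Flax convolution inside FlaxViTPatchEmbeddings.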
| 491 |
+
# Handle any PRNG if needed
|
| 492 |
+
rngs = {}
|
| 493 |
+
if dropout_rng is not None:
|
| 494 |
+
rngs["dropout"] = dropout_rng
|
| 495 |
+
|
| 496 |
+
return self.module.apply(
|
| 497 |
+
{"params": params or self.params},
|
| 498 |
+
jnp.array(pixel_values, dtype=jnp.float32),
|
| 499 |
+
not train,
|
| 500 |
+
output_attentions,
|
| 501 |
+
output_hidden_states,
|
| 502 |
+
return_dict,
|
| 503 |
+
rngs=rngs,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class FlaxViTModule(nn.Module):
|
| 508 |
+
config: ViTConfig
|
| 509 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 510 |
+
add_pooling_layer: bool = True
|
| 511 |
+
|
| 512 |
+
def setup(self):
|
| 513 |
+
self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
|
| 514 |
+
self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
|
| 515 |
+
self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
|
| 516 |
+
self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
|
| 517 |
+
|
| 518 |
+
def __call__(
|
| 519 |
+
self,
|
| 520 |
+
pixel_values,
|
| 521 |
+
deterministic: bool = True,
|
| 522 |
+
output_attentions: bool = False,
|
| 523 |
+
output_hidden_states: bool = False,
|
| 524 |
+
return_dict: bool = True,
|
| 525 |
+
):
|
| 526 |
+
hidden_states = self.embeddings(pixel_values, deterministic=deterministic)
|
| 527 |
+
|
| 528 |
+
outputs = self.encoder(
|
| 529 |
+
hidden_states,
|
| 530 |
+
deterministic=deterministic,
|
| 531 |
+
output_attentions=output_attentions,
|
| 532 |
+
output_hidden_states=output_hidden_states,
|
| 533 |
+
return_dict=return_dict,
|
| 534 |
+
)
|
| 535 |
+
hidden_states = outputs[0]
|
| 536 |
+
hidden_states = self.layernorm(hidden_states)
|
| 537 |
+
pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
|
| 538 |
+
|
| 539 |
+
if not return_dict:
|
| 540 |
+
# if pooled is None, don't return it
|
| 541 |
+
if pooled is None:
|
| 542 |
+
return (hidden_states,) + outputs[1:]
|
| 543 |
+
return (hidden_states, pooled) + outputs[1:]
|
| 544 |
+
|
| 545 |
+
return FlaxBaseModelOutputWithPooling(
|
| 546 |
+
last_hidden_state=hidden_states,
|
| 547 |
+
pooler_output=pooled,
|
| 548 |
+
hidden_states=outputs.hidden_states,
|
| 549 |
+
attentions=outputs.attentions,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@add_start_docstrings(
|
| 554 |
+
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
| 555 |
+
VIT_START_DOCSTRING,
|
| 556 |
+
)
|
| 557 |
+
class FlaxViTModel(FlaxViTPreTrainedModel):
|
| 558 |
+
module_class = FlaxViTModule
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
FLAX_VISION_MODEL_DOCSTRING = """
|
| 562 |
+
Returns:
|
| 563 |
+
|
| 564 |
+
Examples:
|
| 565 |
+
|
| 566 |
+
```python
|
| 567 |
+
>>> from transformers import AutoImageProcessor, FlaxViTModel
|
| 568 |
+
>>> from PIL import Image
|
| 569 |
+
>>> import requests
|
| 570 |
+
|
| 571 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 572 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 573 |
+
|
| 574 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 575 |
+
>>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 576 |
+
|
| 577 |
+
>>> inputs = image_processor(images=image, return_tensors="np")
|
| 578 |
+
>>> outputs = model(**inputs)
|
| 579 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 580 |
+
```
|
| 581 |
+
"""
|
| 582 |
+
|
| 583 |
+
overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
|
| 584 |
+
append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class FlaxViTForImageClassificationModule(nn.Module):
|
| 588 |
+
config: ViTConfig
|
| 589 |
+
dtype: jnp.dtype = jnp.float32
|
| 590 |
+
|
| 591 |
+
def setup(self):
|
| 592 |
+
self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
|
| 593 |
+
self.classifier = nn.Dense(
|
| 594 |
+
self.config.num_labels,
|
| 595 |
+
dtype=self.dtype,
|
| 596 |
+
kernel_init=jax.nn.initializers.variance_scaling(
|
| 597 |
+
self.config.initializer_range**2, "fan_in", "truncated_normal"
|
| 598 |
+
),
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
def __call__(
|
| 602 |
+
self,
|
| 603 |
+
pixel_values=None,
|
| 604 |
+
deterministic: bool = True,
|
| 605 |
+
output_attentions=None,
|
| 606 |
+
output_hidden_states=None,
|
| 607 |
+
return_dict=None,
|
| 608 |
+
):
|
| 609 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 610 |
+
|
| 611 |
+
outputs = self.vit(
|
| 612 |
+
pixel_values,
|
| 613 |
+
deterministic=deterministic,
|
| 614 |
+
output_attentions=output_attentions,
|
| 615 |
+
output_hidden_states=output_hidden_states,
|
| 616 |
+
return_dict=return_dict,
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
hidden_states = outputs[0]
|
| 620 |
+
logits = self.classifier(hidden_states[:, 0, :])
|
| 621 |
+
|
| 622 |
+
if not return_dict:
|
| 623 |
+
output = (logits,) + outputs[2:]
|
| 624 |
+
return output
|
| 625 |
+
|
| 626 |
+
return FlaxSequenceClassifierOutput(
|
| 627 |
+
logits=logits,
|
| 628 |
+
hidden_states=outputs.hidden_states,
|
| 629 |
+
attentions=outputs.attentions,
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
@add_start_docstrings(
|
| 634 |
+
"""
|
| 635 |
+
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
| 636 |
+
the [CLS] token) e.g. for ImageNet.
|
| 637 |
+
""",
|
| 638 |
+
VIT_START_DOCSTRING,
|
| 639 |
+
)
|
| 640 |
+
class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
|
| 641 |
+
module_class = FlaxViTForImageClassificationModule
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
FLAX_VISION_CLASSIF_DOCSTRING = """
|
| 645 |
+
Returns:
|
| 646 |
+
|
| 647 |
+
Example:
|
| 648 |
+
|
| 649 |
+
```python
|
| 650 |
+
>>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
|
| 651 |
+
>>> from PIL import Image
|
| 652 |
+
>>> import jax
|
| 653 |
+
>>> import requests
|
| 654 |
+
|
| 655 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 656 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 657 |
+
|
| 658 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
|
| 659 |
+
>>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
| 660 |
+
|
| 661 |
+
>>> inputs = image_processor(images=image, return_tensors="np")
|
| 662 |
+
>>> outputs = model(**inputs)
|
| 663 |
+
>>> logits = outputs.logits
|
| 664 |
+
|
| 665 |
+
>>> # model predicts one of the 1000 ImageNet classes
|
| 666 |
+
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
|
| 667 |
+
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
|
| 668 |
+
```
|
| 669 |
+
"""
|
| 670 |
+
|
| 671 |
+
overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
|
| 672 |
+
append_replace_return_docstrings(
|
| 673 |
+
FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
__all__ = ["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]
|
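For quick reference, a minimal usage sketch of the Flax classes defined above; it assumes `jax`, `flax`, and `transformers` are installed and uses a randomly initialized `ViTConfig` rather than a pretrained checkpoint, mirroring the docstring example.

```python
import jax.numpy as jnp

from transformers import FlaxViTModel, ViTConfig

config = ViTConfig()  # defaults: 224x224 images, 16x16 patches, hidden_size 768
model = FlaxViTModel(config)  # weights are randomly initialized from the config

# __call__ transposes NCHW -> NHWC internally, so PyTorch-style inputs are accepted.
pixel_values = jnp.zeros((1, config.num_channels, config.image_size, config.image_size))
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 197, 768) = (batch, num_patches + 1, hidden_size)
```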
docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py
ADDED
|
@@ -0,0 +1,907 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 ViT model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import collections.abc
|
| 20 |
+
import math
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
import tensorflow as tf
|
| 25 |
+
|
| 26 |
+
from ...activations_tf import get_tf_activation
|
| 27 |
+
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
|
| 28 |
+
from ...modeling_tf_utils import (
|
| 29 |
+
TFModelInputType,
|
| 30 |
+
TFPreTrainedModel,
|
| 31 |
+
TFSequenceClassificationLoss,
|
| 32 |
+
get_initializer,
|
| 33 |
+
keras,
|
| 34 |
+
keras_serializable,
|
| 35 |
+
unpack_inputs,
|
| 36 |
+
)
|
| 37 |
+
from ...tf_utils import shape_list, stable_softmax
|
| 38 |
+
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
| 39 |
+
from .configuration_vit import ViTConfig
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
# General docstring
|
| 45 |
+
_CONFIG_FOR_DOC = "ViTConfig"
|
| 46 |
+
|
| 47 |
+
# Base docstring
|
| 48 |
+
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
|
| 49 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
| 50 |
+
|
| 51 |
+
# Image classification docstring
|
| 52 |
+
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
|
| 53 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TFViTEmbeddings(keras.layers.Layer):
|
| 57 |
+
"""
|
| 58 |
+
Construct the CLS token, position and patch embeddings.
|
| 59 |
+
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 63 |
+
super().__init__(**kwargs)
|
| 64 |
+
|
| 65 |
+
self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
|
| 66 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 67 |
+
self.config = config
|
| 68 |
+
|
| 69 |
+
def build(self, input_shape=None):
|
| 70 |
+
num_patches = self.patch_embeddings.num_patches
|
| 71 |
+
self.cls_token = self.add_weight(
|
| 72 |
+
shape=(1, 1, self.config.hidden_size),
|
| 73 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 74 |
+
trainable=True,
|
| 75 |
+
name="cls_token",
|
| 76 |
+
)
|
| 77 |
+
self.position_embeddings = self.add_weight(
|
| 78 |
+
shape=(1, num_patches + 1, self.config.hidden_size),
|
| 79 |
+
initializer=get_initializer(self.config.initializer_range),
|
| 80 |
+
trainable=True,
|
| 81 |
+
name="position_embeddings",
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if self.built:
|
| 85 |
+
return
|
| 86 |
+
self.built = True
|
| 87 |
+
if getattr(self, "patch_embeddings", None) is not None:
|
| 88 |
+
with tf.name_scope(self.patch_embeddings.name):
|
| 89 |
+
self.patch_embeddings.build(None)
|
| 90 |
+
|
| 91 |
+
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
This method interpolates the pre-trained position encodings, making it possible to use the model on higher
|
| 94 |
+
resolution images.
|
| 95 |
+
|
| 96 |
+
Source:
|
| 97 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
batch_size, seq_len, dim = shape_list(embeddings)
|
| 101 |
+
num_patches = seq_len - 1
|
| 102 |
+
|
| 103 |
+
_, num_positions, _ = shape_list(self.position_embeddings)
|
| 104 |
+
num_positions -= 1
|
| 105 |
+
|
| 106 |
+
if num_patches == num_positions and height == width:
|
| 107 |
+
return self.position_embeddings
|
| 108 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 109 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 110 |
+
h0 = height // self.config.patch_size
|
| 111 |
+
w0 = width // self.config.patch_size
|
| 112 |
+
patch_pos_embed = tf.image.resize(
|
| 113 |
+
images=tf.reshape(
|
| 114 |
+
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
| 115 |
+
),
|
| 116 |
+
size=(h0, w0),
|
| 117 |
+
method="bicubic",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
shape = shape_list(patch_pos_embed)
|
| 121 |
+
assert h0 == shape[-3] and w0 == shape[-2]
|
| 122 |
+
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
| 123 |
+
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
| 124 |
+
|
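# Worked example of the resizing above (illustrative numbers): a checkpoint pretrained at
# 224x224 with patch_size=16 stores a 14x14 grid of patch position embeddings
# (num_positions = 196). For a 384x384 input, h0 = w0 = 384 // 16 = 24, so the grid is
# bicubically resized to 24x24 and flattened back to shape (1, 576, dim) before the class
# position embedding is concatenated back on.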
| 125 |
+
def call(
|
| 126 |
+
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
| 127 |
+
) -> tf.Tensor:
|
| 128 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 129 |
+
embeddings = self.patch_embeddings(
|
| 130 |
+
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# add the [CLS] token to the embedded patch tokens
|
| 134 |
+
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
|
| 135 |
+
embeddings = tf.concat((cls_tokens, embeddings), axis=1)
|
| 136 |
+
|
| 137 |
+
# add positional encoding to each token
|
| 138 |
+
if interpolate_pos_encoding:
|
| 139 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 140 |
+
else:
|
| 141 |
+
embeddings = embeddings + self.position_embeddings
|
| 142 |
+
|
| 143 |
+
embeddings = self.dropout(embeddings, training=training)
|
| 144 |
+
|
| 145 |
+
return embeddings
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Based on timm implementation, which can be found here:
|
| 149 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 150 |
+
class TFViTPatchEmbeddings(keras.layers.Layer):
|
| 151 |
+
"""
|
| 152 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 153 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 154 |
+
Transformer.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 158 |
+
super().__init__(**kwargs)
|
| 159 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 160 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 161 |
+
|
| 162 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 163 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 164 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 165 |
+
self.image_size = image_size
|
| 166 |
+
self.patch_size = patch_size
|
| 167 |
+
self.num_patches = num_patches
|
| 168 |
+
self.num_channels = num_channels
|
| 169 |
+
self.config = config
|
| 170 |
+
|
| 171 |
+
self.projection = keras.layers.Conv2D(
|
| 172 |
+
filters=hidden_size,
|
| 173 |
+
kernel_size=patch_size,
|
| 174 |
+
strides=patch_size,
|
| 175 |
+
padding="valid",
|
| 176 |
+
data_format="channels_last",
|
| 177 |
+
use_bias=True,
|
| 178 |
+
kernel_initializer=get_initializer(self.config.initializer_range),
|
| 179 |
+
bias_initializer="zeros",
|
| 180 |
+
name="projection",
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def call(
|
| 184 |
+
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
|
| 185 |
+
) -> tf.Tensor:
|
| 186 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 187 |
+
if tf.executing_eagerly() and num_channels != self.num_channels:
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 190 |
+
)
|
| 191 |
+
if not interpolate_pos_encoding:
|
| 192 |
+
if tf.executing_eagerly():
|
| 193 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 194 |
+
raise ValueError(
|
| 195 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
| 196 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
|
| 200 |
+
# So change the input format from `NCHW` to `NHWC`.
|
| 201 |
+
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
| 202 |
+
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
| 203 |
+
|
| 204 |
+
projection = self.projection(pixel_values)
|
| 205 |
+
|
| 206 |
+
# Change the 2D spatial dimensions to a single temporal dimension.
|
| 207 |
+
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
| 208 |
+
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
| 209 |
+
embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
| 210 |
+
|
| 211 |
+
return embeddings
|
| 212 |
+
|
| 213 |
+
def build(self, input_shape=None):
|
| 214 |
+
if self.built:
|
| 215 |
+
return
|
| 216 |
+
self.built = True
|
| 217 |
+
if getattr(self, "projection", None) is not None:
|
| 218 |
+
with tf.name_scope(self.projection.name):
|
| 219 |
+
self.projection.build([None, None, None, self.num_channels])
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class TFViTSelfAttention(keras.layers.Layer):
|
| 223 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 224 |
+
super().__init__(**kwargs)
|
| 225 |
+
|
| 226 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
| 229 |
+
f"of attention heads ({config.num_attention_heads})"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self.num_attention_heads = config.num_attention_heads
|
| 233 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 234 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 235 |
+
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
| 236 |
+
|
| 237 |
+
self.query = keras.layers.Dense(
|
| 238 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
| 239 |
+
)
|
| 240 |
+
self.key = keras.layers.Dense(
|
| 241 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
| 242 |
+
)
|
| 243 |
+
self.value = keras.layers.Dense(
|
| 244 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
| 245 |
+
)
|
| 246 |
+
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
| 247 |
+
self.config = config
|
| 248 |
+
|
| 249 |
+
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
| 250 |
+
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 251 |
+
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
| 252 |
+
|
| 253 |
+
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
| 254 |
+
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 255 |
+
|
| 256 |
+
def call(
|
| 257 |
+
self,
|
| 258 |
+
hidden_states: tf.Tensor,
|
| 259 |
+
head_mask: tf.Tensor,
|
| 260 |
+
output_attentions: bool,
|
| 261 |
+
training: bool = False,
|
| 262 |
+
) -> Tuple[tf.Tensor]:
|
| 263 |
+
batch_size = shape_list(hidden_states)[0]
|
| 264 |
+
mixed_query_layer = self.query(inputs=hidden_states)
|
| 265 |
+
mixed_key_layer = self.key(inputs=hidden_states)
|
| 266 |
+
mixed_value_layer = self.value(inputs=hidden_states)
|
| 267 |
+
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
| 268 |
+
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
| 269 |
+
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
| 270 |
+
|
| 271 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 272 |
+
# (batch size, num_heads, seq_len_q, seq_len_k)
|
| 273 |
+
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
| 274 |
+
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
|
| 275 |
+
attention_scores = tf.divide(attention_scores, dk)
|
| 276 |
+
|
| 277 |
+
# Normalize the attention scores to probabilities.
|
| 278 |
+
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
|
| 279 |
+
|
| 280 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 281 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 282 |
+
attention_probs = self.dropout(inputs=attention_probs, training=training)
|
| 283 |
+
|
| 284 |
+
# Mask heads if we want to
|
| 285 |
+
if head_mask is not None:
|
| 286 |
+
attention_probs = tf.multiply(attention_probs, head_mask)
|
| 287 |
+
|
| 288 |
+
attention_output = tf.matmul(attention_probs, value_layer)
|
| 289 |
+
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
|
| 290 |
+
|
| 291 |
+
# (batch_size, seq_len_q, all_head_size)
|
| 292 |
+
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
| 293 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
| 294 |
+
|
| 295 |
+
return outputs
|
| 296 |
+
|
| 297 |
+
def build(self, input_shape=None):
|
| 298 |
+
if self.built:
|
| 299 |
+
return
|
| 300 |
+
self.built = True
|
| 301 |
+
if getattr(self, "query", None) is not None:
|
| 302 |
+
with tf.name_scope(self.query.name):
|
| 303 |
+
self.query.build([None, None, self.config.hidden_size])
|
| 304 |
+
if getattr(self, "key", None) is not None:
|
| 305 |
+
with tf.name_scope(self.key.name):
|
| 306 |
+
self.key.build([None, None, self.config.hidden_size])
|
| 307 |
+
if getattr(self, "value", None) is not None:
|
| 308 |
+
with tf.name_scope(self.value.name):
|
| 309 |
+
self.value.build([None, None, self.config.hidden_size])
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
class TFViTSelfOutput(keras.layers.Layer):
|
| 313 |
+
"""
|
| 314 |
+
The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
|
| 315 |
+
layernorm applied before each block.
|
| 316 |
+
"""
|
| 317 |
+
|
| 318 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 319 |
+
super().__init__(**kwargs)
|
| 320 |
+
|
| 321 |
+
self.dense = keras.layers.Dense(
|
| 322 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 323 |
+
)
|
| 324 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 325 |
+
self.config = config
|
| 326 |
+
|
| 327 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 328 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 329 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 330 |
+
|
| 331 |
+
return hidden_states
|
| 332 |
+
|
| 333 |
+
def build(self, input_shape=None):
|
| 334 |
+
if self.built:
|
| 335 |
+
return
|
| 336 |
+
self.built = True
|
| 337 |
+
if getattr(self, "dense", None) is not None:
|
| 338 |
+
with tf.name_scope(self.dense.name):
|
| 339 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class TFViTAttention(keras.layers.Layer):
|
| 343 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 344 |
+
super().__init__(**kwargs)
|
| 345 |
+
|
| 346 |
+
self.self_attention = TFViTSelfAttention(config, name="attention")
|
| 347 |
+
self.dense_output = TFViTSelfOutput(config, name="output")
|
| 348 |
+
|
| 349 |
+
def prune_heads(self, heads):
|
| 350 |
+
raise NotImplementedError
|
| 351 |
+
|
| 352 |
+
def call(
|
| 353 |
+
self,
|
| 354 |
+
input_tensor: tf.Tensor,
|
| 355 |
+
head_mask: tf.Tensor,
|
| 356 |
+
output_attentions: bool,
|
| 357 |
+
training: bool = False,
|
| 358 |
+
) -> Tuple[tf.Tensor]:
|
| 359 |
+
self_outputs = self.self_attention(
|
| 360 |
+
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
|
| 361 |
+
)
|
| 362 |
+
attention_output = self.dense_output(
|
| 363 |
+
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
| 364 |
+
)
|
| 365 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 366 |
+
|
| 367 |
+
return outputs
|
| 368 |
+
|
| 369 |
+
def build(self, input_shape=None):
|
| 370 |
+
if self.built:
|
| 371 |
+
return
|
| 372 |
+
self.built = True
|
| 373 |
+
if getattr(self, "self_attention", None) is not None:
|
| 374 |
+
with tf.name_scope(self.self_attention.name):
|
| 375 |
+
self.self_attention.build(None)
|
| 376 |
+
if getattr(self, "dense_output", None) is not None:
|
| 377 |
+
with tf.name_scope(self.dense_output.name):
|
| 378 |
+
self.dense_output.build(None)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
class TFViTIntermediate(keras.layers.Layer):
|
| 382 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 383 |
+
super().__init__(**kwargs)
|
| 384 |
+
|
| 385 |
+
self.dense = keras.layers.Dense(
|
| 386 |
+
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
if isinstance(config.hidden_act, str):
|
| 390 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 391 |
+
else:
|
| 392 |
+
self.intermediate_act_fn = config.hidden_act
|
| 393 |
+
self.config = config
|
| 394 |
+
|
| 395 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 396 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 397 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 398 |
+
|
| 399 |
+
return hidden_states
|
| 400 |
+
|
| 401 |
+
def build(self, input_shape=None):
|
| 402 |
+
if self.built:
|
| 403 |
+
return
|
| 404 |
+
self.built = True
|
| 405 |
+
if getattr(self, "dense", None) is not None:
|
| 406 |
+
with tf.name_scope(self.dense.name):
|
| 407 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class TFViTOutput(keras.layers.Layer):
|
| 411 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 412 |
+
super().__init__(**kwargs)
|
| 413 |
+
|
| 414 |
+
self.dense = keras.layers.Dense(
|
| 415 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 416 |
+
)
|
| 417 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 418 |
+
self.config = config
|
| 419 |
+
|
| 420 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 421 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 422 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 423 |
+
hidden_states = hidden_states + input_tensor
|
| 424 |
+
|
| 425 |
+
return hidden_states
|
| 426 |
+
|
| 427 |
+
def build(self, input_shape=None):
|
| 428 |
+
if self.built:
|
| 429 |
+
return
|
| 430 |
+
self.built = True
|
| 431 |
+
if getattr(self, "dense", None) is not None:
|
| 432 |
+
with tf.name_scope(self.dense.name):
|
| 433 |
+
self.dense.build([None, None, self.config.intermediate_size])
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class TFViTLayer(keras.layers.Layer):
|
| 437 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 438 |
+
|
| 439 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 440 |
+
super().__init__(**kwargs)
|
| 441 |
+
|
| 442 |
+
self.attention = TFViTAttention(config, name="attention")
|
| 443 |
+
self.intermediate = TFViTIntermediate(config, name="intermediate")
|
| 444 |
+
self.vit_output = TFViTOutput(config, name="output")
|
| 445 |
+
|
| 446 |
+
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
|
| 447 |
+
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
|
| 448 |
+
self.config = config
|
| 449 |
+
|
| 450 |
+
def call(
|
| 451 |
+
self,
|
| 452 |
+
hidden_states: tf.Tensor,
|
| 453 |
+
head_mask: tf.Tensor,
|
| 454 |
+
output_attentions: bool,
|
| 455 |
+
training: bool = False,
|
| 456 |
+
) -> Tuple[tf.Tensor]:
|
| 457 |
+
attention_outputs = self.attention(
|
| 458 |
+
# in ViT, layernorm is applied before self-attention
|
| 459 |
+
input_tensor=self.layernorm_before(inputs=hidden_states),
|
| 460 |
+
head_mask=head_mask,
|
| 461 |
+
output_attentions=output_attentions,
|
| 462 |
+
training=training,
|
| 463 |
+
)
|
| 464 |
+
attention_output = attention_outputs[0]
|
| 465 |
+
|
| 466 |
+
# first residual connection
|
| 467 |
+
hidden_states = attention_output + hidden_states
|
| 468 |
+
|
| 469 |
+
# in ViT, layernorm is also applied after self-attention
|
| 470 |
+
layer_output = self.layernorm_after(inputs=hidden_states)
|
| 471 |
+
|
| 472 |
+
intermediate_output = self.intermediate(hidden_states=layer_output)
|
| 473 |
+
|
| 474 |
+
# second residual connection is done here
|
| 475 |
+
layer_output = self.vit_output(
|
| 476 |
+
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
|
| 477 |
+
)
|
| 478 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
| 479 |
+
|
| 480 |
+
return outputs
|
| 481 |
+
|
| 482 |
+
def build(self, input_shape=None):
|
| 483 |
+
if self.built:
|
| 484 |
+
return
|
| 485 |
+
self.built = True
|
| 486 |
+
if getattr(self, "attention", None) is not None:
|
| 487 |
+
with tf.name_scope(self.attention.name):
|
| 488 |
+
self.attention.build(None)
|
| 489 |
+
if getattr(self, "intermediate", None) is not None:
|
| 490 |
+
with tf.name_scope(self.intermediate.name):
|
| 491 |
+
self.intermediate.build(None)
|
| 492 |
+
if getattr(self, "vit_output", None) is not None:
|
| 493 |
+
with tf.name_scope(self.vit_output.name):
|
| 494 |
+
self.vit_output.build(None)
|
| 495 |
+
if getattr(self, "layernorm_before", None) is not None:
|
| 496 |
+
with tf.name_scope(self.layernorm_before.name):
|
| 497 |
+
self.layernorm_before.build([None, None, self.config.hidden_size])
|
| 498 |
+
if getattr(self, "layernorm_after", None) is not None:
|
| 499 |
+
with tf.name_scope(self.layernorm_after.name):
|
| 500 |
+
self.layernorm_after.build([None, None, self.config.hidden_size])
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class TFViTEncoder(keras.layers.Layer):
|
| 504 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 505 |
+
super().__init__(**kwargs)
|
| 506 |
+
|
| 507 |
+
self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
| 508 |
+
|
| 509 |
+
def call(
|
| 510 |
+
self,
|
| 511 |
+
hidden_states: tf.Tensor,
|
| 512 |
+
head_mask: tf.Tensor,
|
| 513 |
+
output_attentions: bool,
|
| 514 |
+
output_hidden_states: bool,
|
| 515 |
+
return_dict: bool,
|
| 516 |
+
training: bool = False,
|
| 517 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 518 |
+
all_hidden_states = () if output_hidden_states else None
|
| 519 |
+
all_attentions = () if output_attentions else None
|
| 520 |
+
|
| 521 |
+
for i, layer_module in enumerate(self.layer):
|
| 522 |
+
if output_hidden_states:
|
| 523 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 524 |
+
|
| 525 |
+
layer_outputs = layer_module(
|
| 526 |
+
hidden_states=hidden_states,
|
| 527 |
+
head_mask=head_mask[i],
|
| 528 |
+
output_attentions=output_attentions,
|
| 529 |
+
training=training,
|
| 530 |
+
)
|
| 531 |
+
hidden_states = layer_outputs[0]
|
| 532 |
+
|
| 533 |
+
if output_attentions:
|
| 534 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 535 |
+
|
| 536 |
+
# Add last layer
|
| 537 |
+
if output_hidden_states:
|
| 538 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 539 |
+
|
| 540 |
+
if not return_dict:
|
| 541 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 542 |
+
|
| 543 |
+
return TFBaseModelOutput(
|
| 544 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
def build(self, input_shape=None):
|
| 548 |
+
if self.built:
|
| 549 |
+
return
|
| 550 |
+
self.built = True
|
| 551 |
+
if getattr(self, "layer", None) is not None:
|
| 552 |
+
for layer in self.layer:
|
| 553 |
+
with tf.name_scope(layer.name):
|
| 554 |
+
layer.build(None)
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
@keras_serializable
|
| 558 |
+
class TFViTMainLayer(keras.layers.Layer):
|
| 559 |
+
config_class = ViTConfig
|
| 560 |
+
|
| 561 |
+
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
|
| 562 |
+
super().__init__(**kwargs)
|
| 563 |
+
|
| 564 |
+
self.config = config
|
| 565 |
+
|
| 566 |
+
self.embeddings = TFViTEmbeddings(config, name="embeddings")
|
| 567 |
+
self.encoder = TFViTEncoder(config, name="encoder")
|
| 568 |
+
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
|
| 569 |
+
self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None
|
| 570 |
+
|
| 571 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 572 |
+
return self.embeddings.patch_embeddings
|
| 573 |
+
|
| 574 |
+
def _prune_heads(self, heads_to_prune):
|
| 575 |
+
"""
|
| 576 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 577 |
+
class PreTrainedModel
|
| 578 |
+
"""
|
| 579 |
+
raise NotImplementedError
|
| 580 |
+
|
| 581 |
+
@unpack_inputs
|
| 582 |
+
def call(
|
| 583 |
+
self,
|
| 584 |
+
pixel_values: TFModelInputType | None = None,
|
| 585 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 586 |
+
output_attentions: Optional[bool] = None,
|
| 587 |
+
output_hidden_states: Optional[bool] = None,
|
| 588 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 589 |
+
return_dict: Optional[bool] = None,
|
| 590 |
+
training: bool = False,
|
| 591 |
+
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
| 592 |
+
if pixel_values is None:
|
| 593 |
+
raise ValueError("You have to specify pixel_values")
|
| 594 |
+
|
| 595 |
+
embedding_output = self.embeddings(
|
| 596 |
+
pixel_values=pixel_values,
|
| 597 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 598 |
+
training=training,
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# Prepare head mask if needed
|
| 602 |
+
# 1.0 in head_mask indicate we keep the head
|
| 603 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 604 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 605 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 606 |
+
if head_mask is not None:
|
| 607 |
+
raise NotImplementedError
|
| 608 |
+
else:
|
| 609 |
+
head_mask = [None] * self.config.num_hidden_layers
|
| 610 |
+
|
| 611 |
+
encoder_outputs = self.encoder(
|
| 612 |
+
hidden_states=embedding_output,
|
| 613 |
+
head_mask=head_mask,
|
| 614 |
+
output_attentions=output_attentions,
|
| 615 |
+
output_hidden_states=output_hidden_states,
|
| 616 |
+
return_dict=return_dict,
|
| 617 |
+
training=training,
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
sequence_output = encoder_outputs[0]
|
| 621 |
+
sequence_output = self.layernorm(inputs=sequence_output)
|
| 622 |
+
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
|
| 623 |
+
|
| 624 |
+
if not return_dict:
|
| 625 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 626 |
+
|
| 627 |
+
return TFBaseModelOutputWithPooling(
|
| 628 |
+
last_hidden_state=sequence_output,
|
| 629 |
+
pooler_output=pooled_output,
|
| 630 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 631 |
+
attentions=encoder_outputs.attentions,
|
| 632 |
+
)
|
| 633 |
+
|
| 634 |
+
def build(self, input_shape=None):
|
| 635 |
+
if self.built:
|
| 636 |
+
return
|
| 637 |
+
self.built = True
|
| 638 |
+
if getattr(self, "embeddings", None) is not None:
|
| 639 |
+
with tf.name_scope(self.embeddings.name):
|
| 640 |
+
self.embeddings.build(None)
|
| 641 |
+
if getattr(self, "encoder", None) is not None:
|
| 642 |
+
with tf.name_scope(self.encoder.name):
|
| 643 |
+
self.encoder.build(None)
|
| 644 |
+
if getattr(self, "layernorm", None) is not None:
|
| 645 |
+
with tf.name_scope(self.layernorm.name):
|
| 646 |
+
self.layernorm.build([None, None, self.config.hidden_size])
|
| 647 |
+
if getattr(self, "pooler", None) is not None:
|
| 648 |
+
with tf.name_scope(self.pooler.name):
|
| 649 |
+
self.pooler.build(None)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
class TFViTPreTrainedModel(TFPreTrainedModel):
|
| 653 |
+
"""
|
| 654 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 655 |
+
models.
|
| 656 |
+
"""
|
| 657 |
+
|
| 658 |
+
config_class = ViTConfig
|
| 659 |
+
base_model_prefix = "vit"
|
| 660 |
+
main_input_name = "pixel_values"
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
VIT_START_DOCSTRING = r"""
|
| 664 |
+
|
| 665 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 666 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 667 |
+
etc.)
|
| 668 |
+
|
| 669 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 670 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 671 |
+
behavior.
|
| 672 |
+
|
| 673 |
+
<Tip>
|
| 674 |
+
|
| 675 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 676 |
+
|
| 677 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 678 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 679 |
+
|
| 680 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 681 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 682 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 683 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 684 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 685 |
+
positional argument:
|
| 686 |
+
|
| 687 |
+
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
|
| 688 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 689 |
+
`model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
|
| 690 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 691 |
+
`model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
|
| 692 |
+
|
| 693 |
+
Note that when creating models and layers with
|
| 694 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 695 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 696 |
+
|
| 697 |
+
</Tip>
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
|
| 701 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 702 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 703 |
+
"""
|
| 704 |
+
|
| 705 |
+
VIT_INPUTS_DOCSTRING = r"""
|
| 706 |
+
Args:
|
| 707 |
+
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
|
| 708 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
|
| 709 |
+
for details.
|
| 710 |
+
|
| 711 |
+
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 712 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 713 |
+
|
| 714 |
+
- 1 indicates the head is **not masked**,
|
| 715 |
+
- 0 indicates the head is **masked**.
|
| 716 |
+
|
| 717 |
+
output_attentions (`bool`, *optional*):
|
| 718 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 719 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 720 |
+
config will be used instead.
|
| 721 |
+
output_hidden_states (`bool`, *optional*):
|
| 722 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 723 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 724 |
+
used instead.
|
| 725 |
+
interpolate_pos_encoding (`bool`, *optional*):
|
| 726 |
+
Whether to interpolate the pre-trained position encodings.
|
| 727 |
+
return_dict (`bool`, *optional*):
|
| 728 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
|
| 729 |
+
eager mode, in graph mode the value will always be set to True.
|
| 730 |
+
training (`bool`, *optional*, defaults to `False``):
|
| 731 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 732 |
+
behaviors between training and evaluation).
|
| 733 |
+
"""
|
| 734 |
+
|
| 735 |
+
|
| 736 |
+
@add_start_docstrings(
|
| 737 |
+
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
| 738 |
+
VIT_START_DOCSTRING,
|
| 739 |
+
)
|
| 740 |
+
class TFViTModel(TFViTPreTrainedModel):
|
| 741 |
+
def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
|
| 742 |
+
super().__init__(config, *inputs, **kwargs)
|
| 743 |
+
|
| 744 |
+
self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
|
| 745 |
+
|
| 746 |
+
@unpack_inputs
|
| 747 |
+
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
| 748 |
+
@add_code_sample_docstrings(
|
| 749 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 750 |
+
output_type=TFBaseModelOutputWithPooling,
|
| 751 |
+
config_class=_CONFIG_FOR_DOC,
|
| 752 |
+
modality="vision",
|
| 753 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 754 |
+
)
|
| 755 |
+
def call(
|
| 756 |
+
self,
|
| 757 |
+
pixel_values: TFModelInputType | None = None,
|
| 758 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 759 |
+
output_attentions: Optional[bool] = None,
|
| 760 |
+
output_hidden_states: Optional[bool] = None,
|
| 761 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 762 |
+
return_dict: Optional[bool] = None,
|
| 763 |
+
training: bool = False,
|
| 764 |
+
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
| 765 |
+
outputs = self.vit(
|
| 766 |
+
pixel_values=pixel_values,
|
| 767 |
+
head_mask=head_mask,
|
| 768 |
+
output_attentions=output_attentions,
|
| 769 |
+
output_hidden_states=output_hidden_states,
|
| 770 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 771 |
+
return_dict=return_dict,
|
| 772 |
+
training=training,
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
return outputs
|
| 776 |
+
|
| 777 |
+
def build(self, input_shape=None):
|
| 778 |
+
if self.built:
|
| 779 |
+
return
|
| 780 |
+
self.built = True
|
| 781 |
+
if getattr(self, "vit", None) is not None:
|
| 782 |
+
with tf.name_scope(self.vit.name):
|
| 783 |
+
self.vit.build(None)
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class TFViTPooler(keras.layers.Layer):
|
| 787 |
+
def __init__(self, config: ViTConfig, **kwargs):
|
| 788 |
+
super().__init__(**kwargs)
|
| 789 |
+
|
| 790 |
+
self.dense = keras.layers.Dense(
|
| 791 |
+
units=config.pooler_output_size,
|
| 792 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 793 |
+
activation=config.pooler_act,
|
| 794 |
+
name="dense",
|
| 795 |
+
)
|
| 796 |
+
self.config = config
|
| 797 |
+
|
| 798 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 799 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 800 |
+
# to the first token.
|
| 801 |
+
first_token_tensor = hidden_states[:, 0]
|
| 802 |
+
pooled_output = self.dense(inputs=first_token_tensor)
|
| 803 |
+
|
| 804 |
+
return pooled_output
|
| 805 |
+
|
| 806 |
+
def build(self, input_shape=None):
|
| 807 |
+
if self.built:
|
| 808 |
+
return
|
| 809 |
+
self.built = True
|
| 810 |
+
if getattr(self, "dense", None) is not None:
|
| 811 |
+
with tf.name_scope(self.dense.name):
|
| 812 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
@add_start_docstrings(
|
| 816 |
+
"""
|
| 817 |
+
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
| 818 |
+
the [CLS] token) e.g. for ImageNet.
|
| 819 |
+
|
| 820 |
+
<Tip>
|
| 821 |
+
|
| 822 |
+
Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
|
| 823 |
+
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
| 824 |
+
position embeddings to the higher resolution.
|
| 825 |
+
|
| 826 |
+
</Tip>
|
| 827 |
+
""",
|
| 828 |
+
VIT_START_DOCSTRING,
|
| 829 |
+
)
|
| 830 |
+
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
|
| 831 |
+
def __init__(self, config: ViTConfig, *inputs, **kwargs):
|
| 832 |
+
super().__init__(config, *inputs, **kwargs)
|
| 833 |
+
|
| 834 |
+
self.num_labels = config.num_labels
|
| 835 |
+
self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")
|
| 836 |
+
|
| 837 |
+
# Classifier head
|
| 838 |
+
self.classifier = keras.layers.Dense(
|
| 839 |
+
units=config.num_labels,
|
| 840 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 841 |
+
name="classifier",
|
| 842 |
+
)
|
| 843 |
+
self.config = config
|
| 844 |
+
|
| 845 |
+
@unpack_inputs
|
| 846 |
+
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
| 847 |
+
@add_code_sample_docstrings(
|
| 848 |
+
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
| 849 |
+
output_type=TFSequenceClassifierOutput,
|
| 850 |
+
config_class=_CONFIG_FOR_DOC,
|
| 851 |
+
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
| 852 |
+
)
|
| 853 |
+
def call(
|
| 854 |
+
self,
|
| 855 |
+
pixel_values: TFModelInputType | None = None,
|
| 856 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 857 |
+
output_attentions: Optional[bool] = None,
|
| 858 |
+
output_hidden_states: Optional[bool] = None,
|
| 859 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 860 |
+
return_dict: Optional[bool] = None,
|
| 861 |
+
labels: np.ndarray | tf.Tensor | None = None,
|
| 862 |
+
training: Optional[bool] = False,
|
| 863 |
+
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
| 864 |
+
r"""
|
| 865 |
+
labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
|
| 866 |
+
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
| 867 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 868 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 869 |
+
"""
|
| 870 |
+
|
| 871 |
+
outputs = self.vit(
|
| 872 |
+
pixel_values=pixel_values,
|
| 873 |
+
head_mask=head_mask,
|
| 874 |
+
output_attentions=output_attentions,
|
| 875 |
+
output_hidden_states=output_hidden_states,
|
| 876 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 877 |
+
return_dict=return_dict,
|
| 878 |
+
training=training,
|
| 879 |
+
)
|
| 880 |
+
sequence_output = outputs[0]
|
| 881 |
+
logits = self.classifier(inputs=sequence_output[:, 0, :])
|
| 882 |
+
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
|
| 883 |
+
|
| 884 |
+
if not return_dict:
|
| 885 |
+
output = (logits,) + outputs[2:]
|
| 886 |
+
return ((loss,) + output) if loss is not None else output
|
| 887 |
+
|
| 888 |
+
return TFSequenceClassifierOutput(
|
| 889 |
+
loss=loss,
|
| 890 |
+
logits=logits,
|
| 891 |
+
hidden_states=outputs.hidden_states,
|
| 892 |
+
attentions=outputs.attentions,
|
| 893 |
+
)
|
| 894 |
+
|
| 895 |
+
def build(self, input_shape=None):
|
| 896 |
+
if self.built:
|
| 897 |
+
return
|
| 898 |
+
self.built = True
|
| 899 |
+
if getattr(self, "vit", None) is not None:
|
| 900 |
+
with tf.name_scope(self.vit.name):
|
| 901 |
+
self.vit.build(None)
|
| 902 |
+
if getattr(self, "classifier", None) is not None:
|
| 903 |
+
with tf.name_scope(self.classifier.name):
|
| 904 |
+
self.classifier.build([None, None, self.config.hidden_size])
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
__all__ = ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel"]
|
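As a rough, hedged illustration of how the TF classification head above is typically driven end to end (the checkpoint name and the COCO test image are the same illustrative ones used in the docstrings of these files, not part of this source):

```python
# Minimal sketch, assuming TensorFlow is installed and the example checkpoint/image are reachable.
import requests
import tensorflow as tf
from PIL import Image
from transformers import AutoImageProcessor, TFViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = processor(images=image, return_tensors="tf")  # pixel_values: (1, 3, 224, 224)
logits = model(**inputs).logits                         # (1, num_labels)
print(model.config.id2label[int(tf.argmax(logits, axis=-1)[0])])
```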
docs/transformers/build/lib/transformers/models/vit/modeling_vit.py
ADDED
@@ -0,0 +1,883 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ViT model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
import math
|
| 19 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.checkpoint
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...modeling_outputs import (
|
| 28 |
+
BaseModelOutput,
|
| 29 |
+
BaseModelOutputWithPooling,
|
| 30 |
+
ImageClassifierOutput,
|
| 31 |
+
MaskedImageModelingOutput,
|
| 32 |
+
)
|
| 33 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 34 |
+
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 35 |
+
from ...utils import (
|
| 36 |
+
add_code_sample_docstrings,
|
| 37 |
+
add_start_docstrings,
|
| 38 |
+
add_start_docstrings_to_model_forward,
|
| 39 |
+
logging,
|
| 40 |
+
replace_return_docstrings,
|
| 41 |
+
torch_int,
|
| 42 |
+
)
|
| 43 |
+
from .configuration_vit import ViTConfig
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
logger = logging.get_logger(__name__)
|
| 47 |
+
|
| 48 |
+
# General docstring
|
| 49 |
+
_CONFIG_FOR_DOC = "ViTConfig"
|
| 50 |
+
|
| 51 |
+
# Base docstring
|
| 52 |
+
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
|
| 53 |
+
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
| 54 |
+
|
| 55 |
+
# Image classification docstring
|
| 56 |
+
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
|
| 57 |
+
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ViTEmbeddings(nn.Module):
|
| 61 |
+
"""
|
| 62 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
| 66 |
+
super().__init__()
|
| 67 |
+
|
| 68 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
| 69 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
| 70 |
+
self.patch_embeddings = ViTPatchEmbeddings(config)
|
| 71 |
+
num_patches = self.patch_embeddings.num_patches
|
| 72 |
+
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
| 73 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 74 |
+
self.patch_size = config.patch_size
|
| 75 |
+
self.config = config
|
| 76 |
+
|
| 77 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 78 |
+
"""
|
| 79 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
| 80 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 81 |
+
|
| 82 |
+
Adapted from:
|
| 83 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 84 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
num_patches = embeddings.shape[1] - 1
|
| 88 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 89 |
+
|
| 90 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 91 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 92 |
+
return self.position_embeddings
|
| 93 |
+
|
| 94 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 95 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 96 |
+
|
| 97 |
+
dim = embeddings.shape[-1]
|
| 98 |
+
|
| 99 |
+
new_height = height // self.patch_size
|
| 100 |
+
new_width = width // self.patch_size
|
| 101 |
+
|
| 102 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 103 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 104 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 105 |
+
|
| 106 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 107 |
+
patch_pos_embed,
|
| 108 |
+
size=(new_height, new_width),
|
| 109 |
+
mode="bicubic",
|
| 110 |
+
align_corners=False,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 114 |
+
|
| 115 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 116 |
+
|
| 117 |
+
def forward(
|
| 118 |
+
self,
|
| 119 |
+
pixel_values: torch.Tensor,
|
| 120 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 121 |
+
interpolate_pos_encoding: bool = False,
|
| 122 |
+
) -> torch.Tensor:
|
| 123 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 124 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 125 |
+
|
| 126 |
+
if bool_masked_pos is not None:
|
| 127 |
+
seq_length = embeddings.shape[1]
|
| 128 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
| 129 |
+
# replace the masked visual tokens by mask_tokens
|
| 130 |
+
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
| 131 |
+
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
| 132 |
+
|
| 133 |
+
# add the [CLS] token to the embedded patch tokens
|
| 134 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
| 135 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 136 |
+
|
| 137 |
+
# add positional encoding to each token
|
| 138 |
+
if interpolate_pos_encoding:
|
| 139 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 140 |
+
else:
|
| 141 |
+
embeddings = embeddings + self.position_embeddings
|
| 142 |
+
|
| 143 |
+
embeddings = self.dropout(embeddings)
|
| 144 |
+
|
| 145 |
+
return embeddings
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class ViTPatchEmbeddings(nn.Module):
|
| 149 |
+
"""
|
| 150 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 151 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 152 |
+
Transformer.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
def __init__(self, config):
|
| 156 |
+
super().__init__()
|
| 157 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 158 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 159 |
+
|
| 160 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 161 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 162 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 163 |
+
self.image_size = image_size
|
| 164 |
+
self.patch_size = patch_size
|
| 165 |
+
self.num_channels = num_channels
|
| 166 |
+
self.num_patches = num_patches
|
| 167 |
+
|
| 168 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
| 169 |
+
|
| 170 |
+
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
| 171 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
| 172 |
+
if num_channels != self.num_channels:
|
| 173 |
+
raise ValueError(
|
| 174 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
| 175 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
| 176 |
+
)
|
| 177 |
+
if not interpolate_pos_encoding:
|
| 178 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
| 181 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
| 182 |
+
)
|
| 183 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
| 184 |
+
return embeddings
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def eager_attention_forward(
|
| 188 |
+
module: nn.Module,
|
| 189 |
+
query: torch.Tensor,
|
| 190 |
+
key: torch.Tensor,
|
| 191 |
+
value: torch.Tensor,
|
| 192 |
+
attention_mask: Optional[torch.Tensor],
|
| 193 |
+
scaling: float,
|
| 194 |
+
dropout: float = 0.0,
|
| 195 |
+
**kwargs,
|
| 196 |
+
):
|
| 197 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 198 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
| 199 |
+
|
| 200 |
+
# Normalize the attention scores to probabilities.
|
| 201 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 202 |
+
|
| 203 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 204 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 205 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 206 |
+
|
| 207 |
+
# Mask heads if we want to
|
| 208 |
+
if attention_mask is not None:
|
| 209 |
+
attn_weights = attn_weights * attention_mask
|
| 210 |
+
|
| 211 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 212 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 213 |
+
|
| 214 |
+
return attn_output, attn_weights
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class ViTSelfAttention(nn.Module):
|
| 218 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 219 |
+
super().__init__()
|
| 220 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 221 |
+
raise ValueError(
|
| 222 |
+
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
| 223 |
+
f"heads {config.num_attention_heads}."
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
self.config = config
|
| 227 |
+
self.num_attention_heads = config.num_attention_heads
|
| 228 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 229 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 230 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 231 |
+
self.scaling = self.attention_head_size**-0.5
|
| 232 |
+
self.is_causal = False
|
| 233 |
+
|
| 234 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 235 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 236 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 237 |
+
|
| 238 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 239 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 240 |
+
x = x.view(new_x_shape)
|
| 241 |
+
return x.permute(0, 2, 1, 3)
|
| 242 |
+
|
| 243 |
+
def forward(
|
| 244 |
+
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
| 245 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 246 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 247 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 248 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 249 |
+
|
| 250 |
+
attention_interface: Callable = eager_attention_forward
|
| 251 |
+
if self.config._attn_implementation != "eager":
|
| 252 |
+
if self.config._attn_implementation == "sdpa" and output_attentions:
|
| 253 |
+
logger.warning_once(
|
| 254 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 255 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 256 |
+
)
|
| 257 |
+
else:
|
| 258 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 259 |
+
|
| 260 |
+
context_layer, attention_probs = attention_interface(
|
| 261 |
+
self,
|
| 262 |
+
query_layer,
|
| 263 |
+
key_layer,
|
| 264 |
+
value_layer,
|
| 265 |
+
head_mask,
|
| 266 |
+
is_causal=self.is_causal,
|
| 267 |
+
scaling=self.scaling,
|
| 268 |
+
dropout=0.0 if not self.training else self.dropout_prob,
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 272 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
| 273 |
+
|
| 274 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 275 |
+
|
| 276 |
+
return outputs
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class ViTSelfOutput(nn.Module):
|
| 280 |
+
"""
|
| 281 |
+
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
|
| 282 |
+
layernorm applied before each block.
|
| 283 |
+
"""
|
| 284 |
+
|
| 285 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 288 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 289 |
+
|
| 290 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 291 |
+
hidden_states = self.dense(hidden_states)
|
| 292 |
+
hidden_states = self.dropout(hidden_states)
|
| 293 |
+
|
| 294 |
+
return hidden_states
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class ViTAttention(nn.Module):
|
| 298 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 299 |
+
super().__init__()
|
| 300 |
+
self.attention = ViTSelfAttention(config)
|
| 301 |
+
self.output = ViTSelfOutput(config)
|
| 302 |
+
self.pruned_heads = set()
|
| 303 |
+
|
| 304 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
| 305 |
+
if len(heads) == 0:
|
| 306 |
+
return
|
| 307 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 308 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
# Prune linear layers
|
| 312 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 313 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 314 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 315 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 316 |
+
|
| 317 |
+
# Update hyper params and store pruned heads
|
| 318 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 319 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 320 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 321 |
+
|
| 322 |
+
def forward(
|
| 323 |
+
self,
|
| 324 |
+
hidden_states: torch.Tensor,
|
| 325 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 326 |
+
output_attentions: bool = False,
|
| 327 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 328 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
| 329 |
+
|
| 330 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 331 |
+
|
| 332 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 333 |
+
return outputs
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class ViTIntermediate(nn.Module):
|
| 337 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 340 |
+
if isinstance(config.hidden_act, str):
|
| 341 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 342 |
+
else:
|
| 343 |
+
self.intermediate_act_fn = config.hidden_act
|
| 344 |
+
|
| 345 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 346 |
+
hidden_states = self.dense(hidden_states)
|
| 347 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 348 |
+
|
| 349 |
+
return hidden_states
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class ViTOutput(nn.Module):
|
| 353 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 354 |
+
super().__init__()
|
| 355 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 356 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 357 |
+
|
| 358 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 359 |
+
hidden_states = self.dense(hidden_states)
|
| 360 |
+
hidden_states = self.dropout(hidden_states)
|
| 361 |
+
|
| 362 |
+
hidden_states = hidden_states + input_tensor
|
| 363 |
+
|
| 364 |
+
return hidden_states
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class ViTLayer(nn.Module):
|
| 368 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 369 |
+
|
| 370 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 371 |
+
super().__init__()
|
| 372 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 373 |
+
self.seq_len_dim = 1
|
| 374 |
+
self.attention = ViTAttention(config)
|
| 375 |
+
self.intermediate = ViTIntermediate(config)
|
| 376 |
+
self.output = ViTOutput(config)
|
| 377 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 378 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 379 |
+
|
| 380 |
+
def forward(
|
| 381 |
+
self,
|
| 382 |
+
hidden_states: torch.Tensor,
|
| 383 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 384 |
+
output_attentions: bool = False,
|
| 385 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 386 |
+
self_attention_outputs = self.attention(
|
| 387 |
+
self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
|
| 388 |
+
head_mask,
|
| 389 |
+
output_attentions=output_attentions,
|
| 390 |
+
)
|
| 391 |
+
attention_output = self_attention_outputs[0]
|
| 392 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 393 |
+
|
| 394 |
+
# first residual connection
|
| 395 |
+
hidden_states = attention_output + hidden_states
|
| 396 |
+
|
| 397 |
+
# in ViT, layernorm is also applied after self-attention
|
| 398 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 399 |
+
layer_output = self.intermediate(layer_output)
|
| 400 |
+
|
| 401 |
+
# second residual connection is done here
|
| 402 |
+
layer_output = self.output(layer_output, hidden_states)
|
| 403 |
+
|
| 404 |
+
outputs = (layer_output,) + outputs
|
| 405 |
+
|
| 406 |
+
return outputs
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
class ViTEncoder(nn.Module):
|
| 410 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 411 |
+
super().__init__()
|
| 412 |
+
self.config = config
|
| 413 |
+
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
| 414 |
+
self.gradient_checkpointing = False
|
| 415 |
+
|
| 416 |
+
def forward(
|
| 417 |
+
self,
|
| 418 |
+
hidden_states: torch.Tensor,
|
| 419 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 420 |
+
output_attentions: bool = False,
|
| 421 |
+
output_hidden_states: bool = False,
|
| 422 |
+
return_dict: bool = True,
|
| 423 |
+
) -> Union[tuple, BaseModelOutput]:
|
| 424 |
+
all_hidden_states = () if output_hidden_states else None
|
| 425 |
+
all_self_attentions = () if output_attentions else None
|
| 426 |
+
|
| 427 |
+
for i, layer_module in enumerate(self.layer):
|
| 428 |
+
if output_hidden_states:
|
| 429 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 430 |
+
|
| 431 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 432 |
+
|
| 433 |
+
if self.gradient_checkpointing and self.training:
|
| 434 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 435 |
+
layer_module.__call__,
|
| 436 |
+
hidden_states,
|
| 437 |
+
layer_head_mask,
|
| 438 |
+
output_attentions,
|
| 439 |
+
)
|
| 440 |
+
else:
|
| 441 |
+
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
| 442 |
+
|
| 443 |
+
hidden_states = layer_outputs[0]
|
| 444 |
+
|
| 445 |
+
if output_attentions:
|
| 446 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 447 |
+
|
| 448 |
+
if output_hidden_states:
|
| 449 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 450 |
+
|
| 451 |
+
if not return_dict:
|
| 452 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 453 |
+
return BaseModelOutput(
|
| 454 |
+
last_hidden_state=hidden_states,
|
| 455 |
+
hidden_states=all_hidden_states,
|
| 456 |
+
attentions=all_self_attentions,
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class ViTPreTrainedModel(PreTrainedModel):
|
| 461 |
+
"""
|
| 462 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 463 |
+
models.
|
| 464 |
+
"""
|
| 465 |
+
|
| 466 |
+
config_class = ViTConfig
|
| 467 |
+
base_model_prefix = "vit"
|
| 468 |
+
main_input_name = "pixel_values"
|
| 469 |
+
supports_gradient_checkpointing = True
|
| 470 |
+
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
| 471 |
+
_supports_sdpa = True
|
| 472 |
+
_supports_flash_attn_2 = True
|
| 473 |
+
|
| 474 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
| 475 |
+
"""Initialize the weights"""
|
| 476 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 477 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
| 478 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
| 479 |
+
module.weight.data = nn.init.trunc_normal_(
|
| 480 |
+
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
| 481 |
+
).to(module.weight.dtype)
|
| 482 |
+
if module.bias is not None:
|
| 483 |
+
module.bias.data.zero_()
|
| 484 |
+
elif isinstance(module, nn.LayerNorm):
|
| 485 |
+
module.bias.data.zero_()
|
| 486 |
+
module.weight.data.fill_(1.0)
|
| 487 |
+
elif isinstance(module, ViTEmbeddings):
|
| 488 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
| 489 |
+
module.position_embeddings.data.to(torch.float32),
|
| 490 |
+
mean=0.0,
|
| 491 |
+
std=self.config.initializer_range,
|
| 492 |
+
).to(module.position_embeddings.dtype)
|
| 493 |
+
|
| 494 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
| 495 |
+
module.cls_token.data.to(torch.float32),
|
| 496 |
+
mean=0.0,
|
| 497 |
+
std=self.config.initializer_range,
|
| 498 |
+
).to(module.cls_token.dtype)
|
| 499 |
+
|
| 500 |
+
if module.mask_token is not None:
|
| 501 |
+
module.mask_token.data.zero_()
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
VIT_START_DOCSTRING = r"""
|
| 505 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 506 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
| 507 |
+
behavior.
|
| 508 |
+
|
| 509 |
+
Parameters:
|
| 510 |
+
config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
|
| 511 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 512 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 513 |
+
"""
|
| 514 |
+
|
| 515 |
+
VIT_INPUTS_DOCSTRING = r"""
|
| 516 |
+
Args:
|
| 517 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 518 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
|
| 519 |
+
for details.
|
| 520 |
+
|
| 521 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 522 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 523 |
+
|
| 524 |
+
- 1 indicates the head is **not masked**,
|
| 525 |
+
- 0 indicates the head is **masked**.
|
| 526 |
+
|
| 527 |
+
output_attentions (`bool`, *optional*):
|
| 528 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 529 |
+
tensors for more detail.
|
| 530 |
+
output_hidden_states (`bool`, *optional*):
|
| 531 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 532 |
+
more detail.
|
| 533 |
+
interpolate_pos_encoding (`bool`, *optional*):
|
| 534 |
+
Whether to interpolate the pre-trained position encodings.
|
| 535 |
+
return_dict (`bool`, *optional*):
|
| 536 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 537 |
+
"""
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
@add_start_docstrings(
|
| 541 |
+
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
|
| 542 |
+
VIT_START_DOCSTRING,
|
| 543 |
+
)
|
| 544 |
+
class ViTModel(ViTPreTrainedModel):
|
| 545 |
+
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
|
| 546 |
+
super().__init__(config)
|
| 547 |
+
self.config = config
|
| 548 |
+
|
| 549 |
+
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
|
| 550 |
+
self.encoder = ViTEncoder(config)
|
| 551 |
+
|
| 552 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 553 |
+
self.pooler = ViTPooler(config) if add_pooling_layer else None
|
| 554 |
+
|
| 555 |
+
# Initialize weights and apply final processing
|
| 556 |
+
self.post_init()
|
| 557 |
+
|
| 558 |
+
def get_input_embeddings(self) -> ViTPatchEmbeddings:
|
| 559 |
+
return self.embeddings.patch_embeddings
|
| 560 |
+
|
| 561 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
| 562 |
+
"""
|
| 563 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 564 |
+
class PreTrainedModel
|
| 565 |
+
"""
|
| 566 |
+
for layer, heads in heads_to_prune.items():
|
| 567 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 568 |
+
|
| 569 |
+
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
| 570 |
+
@add_code_sample_docstrings(
|
| 571 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
| 572 |
+
output_type=BaseModelOutputWithPooling,
|
| 573 |
+
config_class=_CONFIG_FOR_DOC,
|
| 574 |
+
modality="vision",
|
| 575 |
+
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
| 576 |
+
)
|
| 577 |
+
def forward(
|
| 578 |
+
self,
|
| 579 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 580 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 581 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 582 |
+
output_attentions: Optional[bool] = None,
|
| 583 |
+
output_hidden_states: Optional[bool] = None,
|
| 584 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 585 |
+
return_dict: Optional[bool] = None,
|
| 586 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 587 |
+
r"""
|
| 588 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
| 589 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 590 |
+
"""
|
| 591 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 592 |
+
output_hidden_states = (
|
| 593 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 594 |
+
)
|
| 595 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 596 |
+
|
| 597 |
+
if pixel_values is None:
|
| 598 |
+
raise ValueError("You have to specify pixel_values")
|
| 599 |
+
|
| 600 |
+
# Prepare head mask if needed
|
| 601 |
+
# 1.0 in head_mask indicate we keep the head
|
| 602 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 603 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 604 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 605 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 606 |
+
|
| 607 |
+
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
|
| 608 |
+
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
| 609 |
+
if pixel_values.dtype != expected_dtype:
|
| 610 |
+
pixel_values = pixel_values.to(expected_dtype)
|
| 611 |
+
|
| 612 |
+
embedding_output = self.embeddings(
|
| 613 |
+
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
encoder_outputs = self.encoder(
|
| 617 |
+
embedding_output,
|
| 618 |
+
head_mask=head_mask,
|
| 619 |
+
output_attentions=output_attentions,
|
| 620 |
+
output_hidden_states=output_hidden_states,
|
| 621 |
+
return_dict=return_dict,
|
| 622 |
+
)
|
| 623 |
+
sequence_output = encoder_outputs[0]
|
| 624 |
+
sequence_output = self.layernorm(sequence_output)
|
| 625 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 626 |
+
|
| 627 |
+
if not return_dict:
|
| 628 |
+
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
| 629 |
+
return head_outputs + encoder_outputs[1:]
|
| 630 |
+
|
| 631 |
+
return BaseModelOutputWithPooling(
|
| 632 |
+
last_hidden_state=sequence_output,
|
| 633 |
+
pooler_output=pooled_output,
|
| 634 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 635 |
+
attentions=encoder_outputs.attentions,
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
class ViTPooler(nn.Module):
|
| 640 |
+
def __init__(self, config: ViTConfig):
|
| 641 |
+
super().__init__()
|
| 642 |
+
self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
|
| 643 |
+
self.activation = ACT2FN[config.pooler_act]
|
| 644 |
+
|
| 645 |
+
def forward(self, hidden_states):
|
| 646 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 647 |
+
# to the first token.
|
| 648 |
+
first_token_tensor = hidden_states[:, 0]
|
| 649 |
+
pooled_output = self.dense(first_token_tensor)
|
| 650 |
+
pooled_output = self.activation(pooled_output)
|
| 651 |
+
return pooled_output
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
@add_start_docstrings(
|
| 655 |
+
"""ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).
|
| 656 |
+
|
| 657 |
+
<Tip>
|
| 658 |
+
|
| 659 |
+
Note that we provide a script to pre-train this model on custom data in our [examples
|
| 660 |
+
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
|
| 661 |
+
|
| 662 |
+
</Tip>
|
| 663 |
+
""",
|
| 664 |
+
VIT_START_DOCSTRING,
|
| 665 |
+
)
|
| 666 |
+
class ViTForMaskedImageModeling(ViTPreTrainedModel):
|
| 667 |
+
def __init__(self, config: ViTConfig) -> None:
|
| 668 |
+
super().__init__(config)
|
| 669 |
+
|
| 670 |
+
self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
|
| 671 |
+
|
| 672 |
+
self.decoder = nn.Sequential(
|
| 673 |
+
nn.Conv2d(
|
| 674 |
+
in_channels=config.hidden_size,
|
| 675 |
+
out_channels=config.encoder_stride**2 * config.num_channels,
|
| 676 |
+
kernel_size=1,
|
| 677 |
+
),
|
| 678 |
+
nn.PixelShuffle(config.encoder_stride),
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# Initialize weights and apply final processing
|
| 682 |
+
self.post_init()
|
| 683 |
+
|
| 684 |
+
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
| 685 |
+
@replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
|
| 686 |
+
def forward(
|
| 687 |
+
self,
|
| 688 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 689 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
| 690 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 691 |
+
output_attentions: Optional[bool] = None,
|
| 692 |
+
output_hidden_states: Optional[bool] = None,
|
| 693 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
| 694 |
+
return_dict: Optional[bool] = None,
|
| 695 |
+
) -> Union[tuple, MaskedImageModelingOutput]:
|
| 696 |
+
r"""
|
| 697 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
|
| 698 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
| 699 |
+
|
| 700 |
+
Returns:
|
| 701 |
+
|
| 702 |
+
Examples:
|
| 703 |
+
```python
|
| 704 |
+
>>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
|
| 705 |
+
>>> import torch
|
| 706 |
+
>>> from PIL import Image
|
| 707 |
+
>>> import requests
|
| 708 |
+
|
| 709 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 710 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 711 |
+
|
| 712 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 713 |
+
>>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
|
| 714 |
+
|
| 715 |
+
>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
|
| 716 |
+
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
|
| 717 |
+
>>> # create random boolean mask of shape (batch_size, num_patches)
|
| 718 |
+
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
| 719 |
+
|
| 720 |
+
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
|
| 721 |
+
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
|
| 722 |
+
>>> list(reconstructed_pixel_values.shape)
|
| 723 |
+
[1, 3, 224, 224]
|
| 724 |
+
```"""
|
| 725 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 726 |
+
|
| 727 |
+
if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
|
| 728 |
+
raise ValueError(
|
| 729 |
+
"When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
|
| 730 |
+
"the reconstructed image has the same dimensions as the input. "
|
| 731 |
+
f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
outputs = self.vit(
|
| 735 |
+
pixel_values,
|
| 736 |
+
bool_masked_pos=bool_masked_pos,
|
| 737 |
+
head_mask=head_mask,
|
| 738 |
+
output_attentions=output_attentions,
|
| 739 |
+
output_hidden_states=output_hidden_states,
|
| 740 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 741 |
+
return_dict=return_dict,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
sequence_output = outputs[0]
|
| 745 |
+
|
| 746 |
+
# Reshape to (batch_size, num_channels, height, width)
|
| 747 |
+
sequence_output = sequence_output[:, 1:]
|
| 748 |
+
batch_size, sequence_length, num_channels = sequence_output.shape
|
| 749 |
+
height = width = math.floor(sequence_length**0.5)
|
| 750 |
+
sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
|
| 751 |
+
|
| 752 |
+
# Reconstruct pixel values
|
| 753 |
+
reconstructed_pixel_values = self.decoder(sequence_output)
|
| 754 |
+
|
| 755 |
+
masked_im_loss = None
|
| 756 |
+
if bool_masked_pos is not None:
|
| 757 |
+
size = self.config.image_size // self.config.patch_size
|
| 758 |
+
bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
|
| 759 |
+
mask = (
|
| 760 |
+
bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
|
| 761 |
+
.repeat_interleave(self.config.patch_size, 2)
|
| 762 |
+
.unsqueeze(1)
|
| 763 |
+
.contiguous()
|
| 764 |
+
)
|
| 765 |
+
reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
|
| 766 |
+
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
|
| 767 |
+
|
| 768 |
+
if not return_dict:
|
| 769 |
+
output = (reconstructed_pixel_values,) + outputs[1:]
|
| 770 |
+
return ((masked_im_loss,) + output) if masked_im_loss is not None else output
|
| 771 |
+
|
| 772 |
+
return MaskedImageModelingOutput(
|
| 773 |
+
loss=masked_im_loss,
|
| 774 |
+
reconstruction=reconstructed_pixel_values,
|
| 775 |
+
hidden_states=outputs.hidden_states,
|
| 776 |
+
attentions=outputs.attentions,
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
@add_start_docstrings(
|
| 781 |
+
"""
|
| 782 |
+
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
|
| 783 |
+
the [CLS] token) e.g. for ImageNet.
|
| 784 |
+
|
| 785 |
+
<Tip>
|
| 786 |
+
|
| 787 |
+
    Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"]
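A minimal inference sketch for the `interpolate_pos_encoding` path documented above, assuming the `google/vit-base-patch16-224` checkpoint and a 384x384 input (both illustrative, not part of this diff):

```python
# Illustrative sketch: classification at a higher resolution than the pre-training resolution.
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Resize inputs to 384x384 instead of the checkpoint's native 224x224 (size override is an assumption).
processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", size={"height": 384, "width": 384}
)
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

image = Image.open("example.jpg")  # placeholder path for any RGB image
inputs = processor(images=image, return_tensors="pt")  # pixel_values: (1, 3, 384, 384)

with torch.no_grad():
    # The pre-trained 14x14 grid of patch position embeddings is interpolated to 24x24 (384 // 16 = 24).
    outputs = model(**inputs, interpolate_pos_encoding=True)

predicted_label = outputs.logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
```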
docs/transformers/build/lib/transformers/models/vit_mae/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vit_mae import *
    from .modeling_tf_vit_mae import *
    from .modeling_vit_mae import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
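A small usage sketch of the lazy-import structure defined above (illustrative; assumes a standard `transformers` installation):

```python
# Illustrative sketch: the _LazyModule above defers submodule imports until attribute access.
import transformers

# Accessing the attribute triggers the import of configuration_vit_mae; the (much heavier)
# modeling_vit_mae / modeling_tf_vit_mae modules are only imported when they are first used.
config = transformers.models.vit_mae.ViTMAEConfig()
print(config.mask_ratio)  # 0.75 by default
```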
docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py
ADDED
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT MAE model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class ViTMAEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT
    MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the ViT
    [facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        decoder_num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        decoder_hidden_size (`int`, *optional*, defaults to 512):
            Dimensionality of the decoder.
        decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of hidden layers in the decoder.
        decoder_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
        mask_ratio (`float`, *optional*, defaults to 0.75):
            The ratio of the number of masked tokens in the input sequence.
        norm_pix_loss (`bool`, *optional*, defaults to `False`):
            Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
            representation quality in the experiments of the authors.

    Example:

    ```python
    >>> from transformers import ViTMAEConfig, ViTMAEModel

    >>> # Initializing a ViT MAE vit-mae-base style configuration
    >>> configuration = ViTMAEConfig()

    >>> # Initializing a model (with random weights) from the vit-mae-base style configuration
    >>> model = ViTMAEModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vit_mae"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        decoder_num_attention_heads=16,
        decoder_hidden_size=512,
        decoder_num_hidden_layers=8,
        decoder_intermediate_size=2048,
        mask_ratio=0.75,
        norm_pix_loss=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.decoder_num_attention_heads = decoder_num_attention_heads
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_num_hidden_layers = decoder_num_hidden_layers
        self.decoder_intermediate_size = decoder_intermediate_size
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss


__all__ = ["ViTMAEConfig"]
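A short illustrative sketch of the patch and masking arithmetic implied by the defaults above (not part of the library):

```python
# Illustrative arithmetic for the default ViTMAEConfig values.
from transformers import ViTMAEConfig

config = ViTMAEConfig()
patches_per_side = config.image_size // config.patch_size   # 224 // 16 = 14
num_patches = patches_per_side**2                            # 196 patches per image
num_visible = int(num_patches * (1 - config.mask_ratio))     # 49 patches seen by the encoder
num_masked = num_patches - num_visible                        # 147 patches reconstructed by the decoder
print(num_patches, num_visible, num_masked)                   # 196 49 147
```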
docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py
ADDED
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT MAE checkpoints from the original repository: https://github.com/facebookresearch/mae"""

import argparse

import requests
import torch
from PIL import Image

from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEImageProcessor


def rename_key(name):
    if "cls_token" in name:
        name = name.replace("cls_token", "vit.embeddings.cls_token")
    if "mask_token" in name:
        name = name.replace("mask_token", "decoder.mask_token")
    if "decoder_pos_embed" in name:
        name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed")
    if "pos_embed" in name and "decoder" not in name:
        name = name.replace("pos_embed", "vit.embeddings.position_embeddings")
    if "patch_embed.proj" in name:
        name = name.replace("patch_embed.proj", "vit.embeddings.patch_embeddings.projection")
    if "patch_embed.norm" in name:
        name = name.replace("patch_embed.norm", "vit.embeddings.norm")
    if "decoder_blocks" in name:
        name = name.replace("decoder_blocks", "decoder.decoder_layers")
    if "blocks" in name:
        name = name.replace("blocks", "vit.encoder.layer")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "attn" in name:
        name = name.replace("attn", "attention.self")
    if "norm1" in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name:
        name = name.replace("norm2", "layernorm_after")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")
    if "decoder_embed" in name:
        name = name.replace("decoder_embed", "decoder.decoder_embed")
    if "decoder_norm" in name:
        name = name.replace("decoder_norm", "decoder.decoder_norm")
    if "decoder_pred" in name:
        name = name.replace("decoder_pred", "decoder.decoder_pred")
    if "norm.weight" in name and "decoder" not in name:
        name = name.replace("norm.weight", "vit.layernorm.weight")
    if "norm.bias" in name and "decoder" not in name:
        name = name.replace("norm.bias", "vit.layernorm.bias")

    return name


def convert_state_dict(orig_state_dict, config):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if "qkv" in key:
            key_split = key.split(".")
            layer_num = int(key_split[1])
            if "decoder_blocks" in key:
                dim = config.decoder_hidden_size
                prefix = "decoder.decoder_layers."
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
            else:
                dim = config.hidden_size
                prefix = "vit.encoder.layer."
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]

        else:
            orig_state_dict[rename_key(key)] = val

    return orig_state_dict


def convert_vit_mae_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    config = ViTMAEConfig()
    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif "huge" in checkpoint_url:
        config.patch_size = 14
        config.hidden_size = 1280
        config.intermediate_size = 5120
        config.num_hidden_layers = 32
        config.num_attention_heads = 16

    model = ViTMAEForPreTraining(config)

    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]

    image_processor = ViTMAEImageProcessor(size=config.image_size)

    new_state_dict = convert_state_dict(state_dict, config)

    model.load_state_dict(new_state_dict)
    model.eval()

    url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"

    image = Image.open(requests.get(url, stream=True).raw)
    image_processor = ViTMAEImageProcessor(size=config.image_size)
    inputs = image_processor(images=image, return_tensors="pt")

    # forward pass
    torch.manual_seed(2)
    outputs = model(**inputs)
    logits = outputs.logits

    if "large" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-0.7309, -0.7128, -1.0169], [-1.0161, -0.9058, -1.1878], [-1.0478, -0.9411, -1.1911]]
        )
    elif "huge" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-1.1599, -0.9199, -1.2221], [-1.1952, -0.9269, -1.2307], [-1.2143, -0.9337, -1.2262]]
        )
    else:
        expected_slice = torch.tensor(
            [[-0.9192, -0.8481, -1.1259], [-1.1349, -1.0034, -1.2599], [-1.1757, -1.0429, -1.2726]]
        )

    # verify logits
    assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)

    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
        type=str,
        help="URL of the checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_vit_mae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
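An illustrative sketch of the fused-`qkv` slicing performed by `convert_state_dict` above, on dummy tensors rather than a real checkpoint (shapes follow the base config, hidden_size=768):

```python
# Illustrative sketch of the qkv split: a timm-style fused projection has its rows stacked as [q; k; v].
import torch

dim = 768  # hidden_size of the base encoder
qkv_weight = torch.randn(3 * dim, dim)

query_w = qkv_weight[:dim, :]          # first `dim` rows -> query projection
key_w = qkv_weight[dim : dim * 2, :]   # middle `dim` rows -> key projection
value_w = qkv_weight[-dim:, :]         # last `dim` rows -> value projection

assert query_w.shape == key_w.shape == value_w.shape == (dim, dim)
```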
docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py
ADDED
@@ -0,0 +1,1375 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""TF 2.0 ViT MAE (masked autoencoder) model."""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import collections.abc
|
| 20 |
+
import math
|
| 21 |
+
from copy import deepcopy
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from typing import Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import tensorflow as tf
|
| 27 |
+
|
| 28 |
+
from ...activations_tf import get_tf_activation
|
| 29 |
+
from ...file_utils import (
|
| 30 |
+
ModelOutput,
|
| 31 |
+
add_start_docstrings,
|
| 32 |
+
add_start_docstrings_to_model_forward,
|
| 33 |
+
replace_return_docstrings,
|
| 34 |
+
)
|
| 35 |
+
from ...modeling_tf_outputs import TFBaseModelOutput
|
| 36 |
+
from ...modeling_tf_utils import (
|
| 37 |
+
TFModelInputType,
|
| 38 |
+
TFPreTrainedModel,
|
| 39 |
+
get_initializer,
|
| 40 |
+
keras,
|
| 41 |
+
keras_serializable,
|
| 42 |
+
unpack_inputs,
|
| 43 |
+
)
|
| 44 |
+
from ...tf_utils import shape_list, stable_softmax
|
| 45 |
+
from ...utils import logging
|
| 46 |
+
from .configuration_vit_mae import ViTMAEConfig
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
_CONFIG_FOR_DOC = "ViTMAEConfig"
|
| 52 |
+
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class TFViTMAEModelOutput(ModelOutput):
|
| 57 |
+
"""
|
| 58 |
+
Class for TFViTMAEModel's outputs, with potential hidden states and attentions.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 62 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 63 |
+
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 64 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 65 |
+
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 66 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 67 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 68 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 69 |
+
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
|
| 70 |
+
the initial embedding outputs.
|
| 71 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 72 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 73 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 74 |
+
the self-attention heads.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
last_hidden_state: Optional[tf.Tensor] = None
|
| 78 |
+
mask: Optional[tf.Tensor] = None
|
| 79 |
+
ids_restore: Optional[tf.Tensor] = None
|
| 80 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 81 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
|
| 85 |
+
class TFViTMAEDecoderOutput(ModelOutput):
|
| 86 |
+
"""
|
| 87 |
+
Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 91 |
+
Pixel reconstruction logits.
|
| 92 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 93 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 94 |
+
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
|
| 95 |
+
the initial embedding outputs.
|
| 96 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 97 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 98 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 99 |
+
the self-attention heads.
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
logits: Optional[tf.Tensor] = None
|
| 103 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 104 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@dataclass
|
| 108 |
+
class TFViTMAEForPreTrainingOutput(ModelOutput):
|
| 109 |
+
"""
|
| 110 |
+
Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
loss (`tf.Tensor` of shape `(1,)`):
|
| 114 |
+
Pixel reconstruction loss.
|
| 115 |
+
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 116 |
+
Pixel reconstruction logits.
|
| 117 |
+
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 118 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 119 |
+
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 120 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 121 |
+
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 122 |
+
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
|
| 123 |
+
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
|
| 124 |
+
the initial embedding outputs.
|
| 125 |
+
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 126 |
+
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 127 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 128 |
+
the self-attention heads.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
loss: tf.Tensor | None = None
|
| 132 |
+
logits: Optional[tf.Tensor] = None
|
| 133 |
+
mask: Optional[tf.Tensor] = None
|
| 134 |
+
ids_restore: Optional[tf.Tensor] = None
|
| 135 |
+
hidden_states: Tuple[tf.Tensor] | None = None
|
| 136 |
+
attentions: Tuple[tf.Tensor] | None = None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
| 140 |
+
"""
|
| 141 |
+
Create 2D sin/cos positional embeddings.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
embed_dim (`int`):
|
| 145 |
+
Embedding dimension.
|
| 146 |
+
grid_size (`int`):
|
| 147 |
+
The grid height and width.
|
| 148 |
+
add_cls_token (`bool`, *optional*, defaults to `False`):
|
| 149 |
+
Whether or not to add a classification (CLS) token.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
(`tf.Tensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the position
|
| 153 |
+
embeddings (with or without classification token)
|
| 154 |
+
"""
|
| 155 |
+
grid_h = tf.range(grid_size, dtype=tf.float32)
|
| 156 |
+
grid_w = tf.range(grid_size, dtype=tf.float32)
|
| 157 |
+
grid = tf.meshgrid(grid_w, grid_h) # here w goes first
|
| 158 |
+
grid = tf.stack(grid, axis=0)
|
| 159 |
+
|
| 160 |
+
grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
|
| 161 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 162 |
+
if add_cls_token:
|
| 163 |
+
pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
|
| 164 |
+
return pos_embed
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 168 |
+
if embed_dim % 2 != 0:
|
| 169 |
+
raise ValueError("embed_dim must be even")
|
| 170 |
+
|
| 171 |
+
# use half of dimensions to encode grid_h
|
| 172 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 173 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 174 |
+
|
| 175 |
+
emb = tf.concat([emb_h, emb_w], axis=1) # (H*W, D)
|
| 176 |
+
return emb
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 180 |
+
"""
|
| 181 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
| 182 |
+
"""
|
| 183 |
+
if embed_dim % 2 != 0:
|
| 184 |
+
raise ValueError("embed_dim must be even")
|
| 185 |
+
|
| 186 |
+
omega = tf.range(embed_dim // 2, dtype="float32")
|
| 187 |
+
omega /= embed_dim / 2.0
|
| 188 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 189 |
+
|
| 190 |
+
pos = tf.reshape(pos, [-1]) # (M,)
|
| 191 |
+
out = tf.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 192 |
+
|
| 193 |
+
# half of the positions get sinusoidal pattern and the rest gets
|
| 194 |
+
# cosine pattern and then they are concatenated
|
| 195 |
+
emb_sin = tf.sin(out) # (M, D/2)
|
| 196 |
+
emb_cos = tf.cos(out) # (M, D/2)
|
| 197 |
+
|
| 198 |
+
emb = tf.concat([emb_sin, emb_cos], axis=1) # (M, D)
|
| 199 |
+
return emb
|
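# Illustrative note (not part of the original file): for a position p and even embed_dim D,
# the 1D sin/cos embedding built above is
#     emb[p, i]         = sin(p / 10000**(2i / D))   for i < D/2
#     emb[p, i + D/2]   = cos(p / 10000**(2i / D))
# e.g. with embed_dim=4 and positions [0, 1]:
#     get_1d_sincos_pos_embed_from_grid(4, tf.constant([0.0, 1.0]))
#     -> [[0.0,      0.0,       1.0,      1.0     ],
#         [sin(1.0), sin(0.01), cos(1.0), cos(0.01)]]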
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class TFViTMAEEmbeddings(keras.layers.Layer):
|
| 203 |
+
"""
|
| 204 |
+
Construct the CLS token, position and patch embeddings.
|
| 205 |
+
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 209 |
+
super().__init__(**kwargs)
|
| 210 |
+
|
| 211 |
+
self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
|
| 212 |
+
self.num_patches = self.patch_embeddings.num_patches
|
| 213 |
+
|
| 214 |
+
self.config = config
|
| 215 |
+
|
| 216 |
+
def build(self, input_shape=None):
|
| 217 |
+
self.cls_token = self.add_weight(
|
| 218 |
+
shape=(1, 1, self.config.hidden_size),
|
| 219 |
+
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
|
| 220 |
+
trainable=True,
|
| 221 |
+
name="cls_token",
|
| 222 |
+
)
|
| 223 |
+
self.position_embeddings = self.add_weight(
|
| 224 |
+
shape=(1, self.num_patches + 1, self.config.hidden_size),
|
| 225 |
+
initializer="zeros",
|
| 226 |
+
trainable=False, # fixed sin-cos embedding
|
| 227 |
+
name="position_embeddings",
|
| 228 |
+
)
|
| 229 |
+
pos_embed = get_2d_sincos_pos_embed(
|
| 230 |
+
self.position_embeddings.shape[-1],
|
| 231 |
+
int(self.patch_embeddings.num_patches**0.5),
|
| 232 |
+
add_cls_token=True,
|
| 233 |
+
)[None, ...]
|
| 234 |
+
self.position_embeddings.assign(pos_embed)
|
| 235 |
+
|
| 236 |
+
if self.built:
|
| 237 |
+
return
|
| 238 |
+
self.built = True
|
| 239 |
+
if getattr(self, "patch_embeddings", None) is not None:
|
| 240 |
+
with tf.name_scope(self.patch_embeddings.name):
|
| 241 |
+
self.patch_embeddings.build(None)
|
| 242 |
+
|
| 243 |
+
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
|
| 244 |
+
"""
|
| 245 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
| 246 |
+
resolution images.
|
| 247 |
+
|
| 248 |
+
Source:
|
| 249 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
batch_size, seq_len, dim = shape_list(embeddings)
|
| 253 |
+
num_patches = seq_len - 1
|
| 254 |
+
|
| 255 |
+
_, num_positions, _ = shape_list(self.position_embeddings)
|
| 256 |
+
num_positions -= 1
|
| 257 |
+
|
| 258 |
+
if num_patches == num_positions and height == width:
|
| 259 |
+
return self.position_embeddings
|
| 260 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 261 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 262 |
+
h0 = height // self.config.patch_size
|
| 263 |
+
w0 = width // self.config.patch_size
|
| 264 |
+
patch_pos_embed = tf.image.resize(
|
| 265 |
+
images=tf.reshape(
|
| 266 |
+
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
| 267 |
+
),
|
| 268 |
+
size=(h0, w0),
|
| 269 |
+
method="bicubic",
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
| 273 |
+
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
|
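# Illustrative note (not part of the original file): for a 384x384 input with patch_size 16 and a
# checkpoint trained at 224x224, num_patches = (384 // 16) ** 2 = 576 while num_positions = 196,
# so the 14x14 grid of pre-trained patch position embeddings is bicubically resized to
# h0 x w0 = 24 x 24 and re-flattened before being concatenated with the [CLS] position embedding.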
| 274 |
+
|
| 275 |
+
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
|
| 276 |
+
"""
|
| 277 |
+
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
| 278 |
+
noise.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`)
|
| 282 |
+
noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*) which is
|
| 283 |
+
mainly used for testing purposes to control randomness and maintain the reproducibility
|
| 284 |
+
"""
|
| 285 |
+
batch_size, seq_length, dim = shape_list(sequence)
|
| 286 |
+
len_keep = int(seq_length * (1 - self.config.mask_ratio))
|
| 287 |
+
|
| 288 |
+
if noise is None:
|
| 289 |
+
noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0) # noise in [0, 1)
|
| 290 |
+
|
| 291 |
+
# sort noise for each sample
|
| 292 |
+
ids_shuffle = tf.argsort(noise, axis=1) # ascend: small is keep, large is remove
|
| 293 |
+
ids_restore = tf.argsort(ids_shuffle, axis=1)
|
| 294 |
+
|
| 295 |
+
# keep the first subset
|
| 296 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
| 297 |
+
sequence_unmasked = tf.gather(
|
| 298 |
+
sequence,
|
| 299 |
+
axis=1,
|
| 300 |
+
batch_dims=1,
|
| 301 |
+
indices=ids_keep,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
| 305 |
+
# this hack is needed because TF's EagerTensors don't support
|
| 306 |
+
# assignment
|
| 307 |
+
mask_keep = tf.zeros((batch_size, len_keep))
|
| 308 |
+
mask_remove = tf.ones((batch_size, seq_length - len_keep))
|
| 309 |
+
mask = tf.concat([mask_keep, mask_remove], axis=-1)
|
| 310 |
+
|
| 311 |
+
# unshuffle to get the binary mask
|
| 312 |
+
mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
|
| 313 |
+
|
| 314 |
+
return sequence_unmasked, mask, ids_restore
|
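# Illustrative note (not part of the original file): with the default mask_ratio of 0.75 and a
# 14x14 patch grid (seq_length = 196), `random_masking` keeps len_keep = int(196 * 0.25) = 49
# patches per sample. `ids_shuffle` orders patches by ascending noise, the first 49 indices are
# gathered into `sequence_unmasked` (shape (batch_size, 49, dim)), and `ids_restore` is the inverse
# permutation later used by the decoder to scatter mask tokens back to their original positions.
# The returned `mask` has shape (batch_size, 196), with 0 for kept and 1 for masked patches.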
| 315 |
+
|
| 316 |
+
def call(
|
| 317 |
+
self, pixel_values: tf.Tensor, noise: Optional[tf.Tensor] = None, interpolate_pos_encoding: bool = False
|
| 318 |
+
) -> tf.Tensor:
|
| 319 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 320 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 321 |
+
if interpolate_pos_encoding:
|
| 322 |
+
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
|
| 323 |
+
else:
|
| 324 |
+
position_embeddings = self.position_embeddings
|
| 325 |
+
# add position embeddings w/o cls token
|
| 326 |
+
embeddings = embeddings + position_embeddings[:, 1:, :]
|
| 327 |
+
|
| 328 |
+
# masking: length -> length * config.mask_ratio
|
| 329 |
+
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
|
| 330 |
+
|
| 331 |
+
# append cls token
|
| 332 |
+
cls_token = self.cls_token + position_embeddings[:, :1, :]
|
| 333 |
+
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
|
| 334 |
+
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
|
| 335 |
+
|
| 336 |
+
return embeddings, mask, ids_restore
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class TFViTMAEPatchEmbeddings(keras.layers.Layer):
|
| 340 |
+
"""
|
| 341 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
| 342 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
| 343 |
+
Transformer.
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 347 |
+
super().__init__(**kwargs)
|
| 348 |
+
image_size, patch_size = config.image_size, config.patch_size
|
| 349 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
| 350 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
| 351 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
| 352 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
| 353 |
+
self.image_size = image_size
|
| 354 |
+
self.patch_size = patch_size
|
| 355 |
+
self.num_patches = num_patches
|
| 356 |
+
self.num_channels = num_channels
|
| 357 |
+
self.config = config
|
| 358 |
+
|
| 359 |
+
self.projection = keras.layers.Conv2D(
|
| 360 |
+
filters=hidden_size,
|
| 361 |
+
kernel_size=patch_size,
|
| 362 |
+
strides=patch_size,
|
| 363 |
+
padding="valid",
|
| 364 |
+
data_format="channels_last",
|
| 365 |
+
kernel_initializer="glorot_uniform", # following torch.nn.Linear
|
| 366 |
+
bias_initializer="zeros",
|
| 367 |
+
name="projection",
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
def call(
|
| 371 |
+
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
|
| 372 |
+
) -> tf.Tensor:
|
| 373 |
+
batch_size, num_channels, height, width = shape_list(pixel_values)
|
| 374 |
+
if tf.executing_eagerly():
|
| 375 |
+
if num_channels != self.num_channels:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the"
|
| 378 |
+
" configuration."
|
| 379 |
+
)
|
| 380 |
+
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
|
| 381 |
+
raise ValueError(
|
| 382 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
| 383 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
|
| 387 |
+
# So change the input format from `NCHW` to `NHWC`.
|
| 388 |
+
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
|
| 389 |
+
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
| 390 |
+
|
| 391 |
+
projection = self.projection(pixel_values)
|
| 392 |
+
|
| 393 |
+
# Change the 2D spatial dimensions to a single temporal dimension.
|
| 394 |
+
# shape = (batch_size, num_patches, out_channels=embed_dim)
|
| 395 |
+
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
|
| 396 |
+
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
|
| 397 |
+
|
| 398 |
+
return x
|
| 399 |
+
|
| 400 |
+
def build(self, input_shape=None):
|
| 401 |
+
if self.built:
|
| 402 |
+
return
|
| 403 |
+
self.built = True
|
| 404 |
+
if getattr(self, "projection", None) is not None:
|
| 405 |
+
with tf.name_scope(self.projection.name):
|
| 406 |
+
self.projection.build([None, None, None, self.num_channels])
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE
|
| 410 |
+
class TFViTMAESelfAttention(keras.layers.Layer):
|
| 411 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 412 |
+
super().__init__(**kwargs)
|
| 413 |
+
|
| 414 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 415 |
+
raise ValueError(
|
| 416 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
| 417 |
+
f"of attention heads ({config.num_attention_heads})"
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
self.num_attention_heads = config.num_attention_heads
|
| 421 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 422 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 423 |
+
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
|
| 424 |
+
|
| 425 |
+
self.query = keras.layers.Dense(
|
| 426 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
| 427 |
+
)
|
| 428 |
+
self.key = keras.layers.Dense(
|
| 429 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
|
| 430 |
+
)
|
| 431 |
+
self.value = keras.layers.Dense(
|
| 432 |
+
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
| 433 |
+
)
|
| 434 |
+
self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
|
| 435 |
+
self.config = config
|
| 436 |
+
|
| 437 |
+
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
|
| 438 |
+
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
| 439 |
+
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
| 440 |
+
|
| 441 |
+
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
|
| 442 |
+
return tf.transpose(tensor, perm=[0, 2, 1, 3])
|
| 443 |
+
|
| 444 |
+
def call(
|
| 445 |
+
self,
|
| 446 |
+
hidden_states: tf.Tensor,
|
| 447 |
+
head_mask: tf.Tensor,
|
| 448 |
+
output_attentions: bool,
|
| 449 |
+
training: bool = False,
|
| 450 |
+
) -> Tuple[tf.Tensor]:
|
| 451 |
+
batch_size = shape_list(hidden_states)[0]
|
| 452 |
+
mixed_query_layer = self.query(inputs=hidden_states)
|
| 453 |
+
mixed_key_layer = self.key(inputs=hidden_states)
|
| 454 |
+
mixed_value_layer = self.value(inputs=hidden_states)
|
| 455 |
+
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
| 456 |
+
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
| 457 |
+
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
| 458 |
+
|
| 459 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 460 |
+
# (batch size, num_heads, seq_len_q, seq_len_k)
|
| 461 |
+
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
|
| 462 |
+
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
|
| 463 |
+
attention_scores = tf.divide(attention_scores, dk)
|
| 464 |
+
|
| 465 |
+
# Normalize the attention scores to probabilities.
|
| 466 |
+
attention_probs = stable_softmax(logits=attention_scores, axis=-1)
|
| 467 |
+
|
| 468 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 469 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 470 |
+
attention_probs = self.dropout(inputs=attention_probs, training=training)
|
| 471 |
+
|
| 472 |
+
# Mask heads if we want to
|
| 473 |
+
if head_mask is not None:
|
| 474 |
+
attention_probs = tf.multiply(attention_probs, head_mask)
|
| 475 |
+
|
| 476 |
+
attention_output = tf.matmul(attention_probs, value_layer)
|
| 477 |
+
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
|
| 478 |
+
|
| 479 |
+
# (batch_size, seq_len_q, all_head_size)
|
| 480 |
+
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
|
| 481 |
+
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
|
| 482 |
+
|
| 483 |
+
return outputs
|
| 484 |
+
|
| 485 |
+
def build(self, input_shape=None):
|
| 486 |
+
if self.built:
|
| 487 |
+
return
|
| 488 |
+
self.built = True
|
| 489 |
+
if getattr(self, "query", None) is not None:
|
| 490 |
+
with tf.name_scope(self.query.name):
|
| 491 |
+
self.query.build([None, None, self.config.hidden_size])
|
| 492 |
+
if getattr(self, "key", None) is not None:
|
| 493 |
+
with tf.name_scope(self.key.name):
|
| 494 |
+
self.key.build([None, None, self.config.hidden_size])
|
| 495 |
+
if getattr(self, "value", None) is not None:
|
| 496 |
+
with tf.name_scope(self.value.name):
|
| 497 |
+
self.value.build([None, None, self.config.hidden_size])
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
|
| 501 |
+
class TFViTMAESelfOutput(keras.layers.Layer):
|
| 502 |
+
"""
|
| 503 |
+
The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the
|
| 504 |
+
layernorm applied before each block.
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 508 |
+
super().__init__(**kwargs)
|
| 509 |
+
|
| 510 |
+
self.dense = keras.layers.Dense(
|
| 511 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 512 |
+
)
|
| 513 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 514 |
+
self.config = config
|
| 515 |
+
|
| 516 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 517 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 518 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 519 |
+
|
| 520 |
+
return hidden_states
|
| 521 |
+
|
| 522 |
+
def build(self, input_shape=None):
|
| 523 |
+
if self.built:
|
| 524 |
+
return
|
| 525 |
+
self.built = True
|
| 526 |
+
if getattr(self, "dense", None) is not None:
|
| 527 |
+
with tf.name_scope(self.dense.name):
|
| 528 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
|
| 532 |
+
class TFViTMAEAttention(keras.layers.Layer):
|
| 533 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 534 |
+
super().__init__(**kwargs)
|
| 535 |
+
|
| 536 |
+
self.self_attention = TFViTMAESelfAttention(config, name="attention")
|
| 537 |
+
self.dense_output = TFViTMAESelfOutput(config, name="output")
|
| 538 |
+
|
| 539 |
+
def prune_heads(self, heads):
|
| 540 |
+
raise NotImplementedError
|
| 541 |
+
|
| 542 |
+
def call(
|
| 543 |
+
self,
|
| 544 |
+
input_tensor: tf.Tensor,
|
| 545 |
+
head_mask: tf.Tensor,
|
| 546 |
+
output_attentions: bool,
|
| 547 |
+
training: bool = False,
|
| 548 |
+
) -> Tuple[tf.Tensor]:
|
| 549 |
+
self_outputs = self.self_attention(
|
| 550 |
+
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
|
| 551 |
+
)
|
| 552 |
+
attention_output = self.dense_output(
|
| 553 |
+
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
|
| 554 |
+
)
|
| 555 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 556 |
+
|
| 557 |
+
return outputs
|
| 558 |
+
|
| 559 |
+
def build(self, input_shape=None):
|
| 560 |
+
if self.built:
|
| 561 |
+
return
|
| 562 |
+
self.built = True
|
| 563 |
+
if getattr(self, "self_attention", None) is not None:
|
| 564 |
+
with tf.name_scope(self.self_attention.name):
|
| 565 |
+
self.self_attention.build(None)
|
| 566 |
+
if getattr(self, "dense_output", None) is not None:
|
| 567 |
+
with tf.name_scope(self.dense_output.name):
|
| 568 |
+
self.dense_output.build(None)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
|
| 572 |
+
class TFViTMAEIntermediate(keras.layers.Layer):
|
| 573 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 574 |
+
super().__init__(**kwargs)
|
| 575 |
+
|
| 576 |
+
self.dense = keras.layers.Dense(
|
| 577 |
+
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
if isinstance(config.hidden_act, str):
|
| 581 |
+
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
| 582 |
+
else:
|
| 583 |
+
self.intermediate_act_fn = config.hidden_act
|
| 584 |
+
self.config = config
|
| 585 |
+
|
| 586 |
+
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
| 587 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 588 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 589 |
+
|
| 590 |
+
return hidden_states
|
| 591 |
+
|
| 592 |
+
def build(self, input_shape=None):
|
| 593 |
+
if self.built:
|
| 594 |
+
return
|
| 595 |
+
self.built = True
|
| 596 |
+
if getattr(self, "dense", None) is not None:
|
| 597 |
+
with tf.name_scope(self.dense.name):
|
| 598 |
+
self.dense.build([None, None, self.config.hidden_size])
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
|
| 602 |
+
class TFViTMAEOutput(keras.layers.Layer):
|
| 603 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 604 |
+
super().__init__(**kwargs)
|
| 605 |
+
|
| 606 |
+
self.dense = keras.layers.Dense(
|
| 607 |
+
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
| 608 |
+
)
|
| 609 |
+
self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
| 610 |
+
self.config = config
|
| 611 |
+
|
| 612 |
+
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
| 613 |
+
hidden_states = self.dense(inputs=hidden_states)
|
| 614 |
+
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
| 615 |
+
hidden_states = hidden_states + input_tensor
|
| 616 |
+
|
| 617 |
+
return hidden_states
|
| 618 |
+
|
| 619 |
+
def build(self, input_shape=None):
|
| 620 |
+
if self.built:
|
| 621 |
+
return
|
| 622 |
+
self.built = True
|
| 623 |
+
if getattr(self, "dense", None) is not None:
|
| 624 |
+
with tf.name_scope(self.dense.name):
|
| 625 |
+
self.dense.build([None, None, self.config.intermediate_size])
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
|
| 629 |
+
class TFViTMAELayer(keras.layers.Layer):
|
| 630 |
+
"""This corresponds to the Block class in the timm implementation."""
|
| 631 |
+
|
| 632 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 633 |
+
super().__init__(**kwargs)
|
| 634 |
+
|
| 635 |
+
self.attention = TFViTMAEAttention(config, name="attention")
|
| 636 |
+
self.intermediate = TFViTMAEIntermediate(config, name="intermediate")
|
| 637 |
+
self.vit_output = TFViTMAEOutput(config, name="output")
|
| 638 |
+
|
| 639 |
+
self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
|
| 640 |
+
self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
|
| 641 |
+
self.config = config
|
| 642 |
+
|
| 643 |
+
def call(
|
| 644 |
+
self,
|
| 645 |
+
hidden_states: tf.Tensor,
|
| 646 |
+
head_mask: tf.Tensor,
|
| 647 |
+
output_attentions: bool,
|
| 648 |
+
training: bool = False,
|
| 649 |
+
) -> Tuple[tf.Tensor]:
|
| 650 |
+
attention_outputs = self.attention(
|
| 651 |
+
# in ViTMAE, layernorm is applied before self-attention
|
| 652 |
+
input_tensor=self.layernorm_before(inputs=hidden_states),
|
| 653 |
+
head_mask=head_mask,
|
| 654 |
+
output_attentions=output_attentions,
|
| 655 |
+
training=training,
|
| 656 |
+
)
|
| 657 |
+
attention_output = attention_outputs[0]
|
| 658 |
+
|
| 659 |
+
# first residual connection
|
| 660 |
+
hidden_states = attention_output + hidden_states
|
| 661 |
+
|
| 662 |
+
# in ViTMAE, layernorm is also applied after self-attention
|
| 663 |
+
layer_output = self.layernorm_after(inputs=hidden_states)
|
| 664 |
+
|
| 665 |
+
intermediate_output = self.intermediate(hidden_states=layer_output)
|
| 666 |
+
|
| 667 |
+
# second residual connection is done here
|
| 668 |
+
layer_output = self.vit_output(
|
| 669 |
+
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
|
| 670 |
+
)
|
| 671 |
+
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
| 672 |
+
|
| 673 |
+
return outputs
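The two residual additions above follow the pre-layernorm ordering that the TFViTMAESelfOutput docstring refers to: layernorm is applied before each sub-layer and the residual is added outside of it. A minimal, self-contained sketch of that ordering, using plain Keras layers as stand-ins for the attention and MLP sub-layers (the names below are illustrative placeholders, not the classes defined in this file):

```python
import tensorflow as tf

def pre_layernorm_block(x, attn, mlp, ln_before, ln_after):
    # first residual connection: x + Attention(LayerNorm(x))
    x = x + attn(ln_before(x))
    # second residual connection: x + MLP(LayerNorm(x))
    x = x + mlp(ln_after(x))
    return x

dim = 8
ln_before = tf.keras.layers.LayerNormalization(epsilon=1e-12)
ln_after = tf.keras.layers.LayerNormalization(epsilon=1e-12)
attn = tf.keras.layers.Dense(dim)  # stand-in for the self-attention sub-layer
mlp = tf.keras.layers.Dense(dim)   # stand-in for the intermediate/output MLP

hidden_states = tf.random.normal((2, 5, dim))  # (batch, seq_len, hidden_size)
print(pre_layernorm_block(hidden_states, attn, mlp, ln_before, ln_after).shape)  # (2, 5, 8)
```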
|
| 674 |
+
|
| 675 |
+
def build(self, input_shape=None):
|
| 676 |
+
if self.built:
|
| 677 |
+
return
|
| 678 |
+
self.built = True
|
| 679 |
+
if getattr(self, "attention", None) is not None:
|
| 680 |
+
with tf.name_scope(self.attention.name):
|
| 681 |
+
self.attention.build(None)
|
| 682 |
+
if getattr(self, "intermediate", None) is not None:
|
| 683 |
+
with tf.name_scope(self.intermediate.name):
|
| 684 |
+
self.intermediate.build(None)
|
| 685 |
+
if getattr(self, "vit_output", None) is not None:
|
| 686 |
+
with tf.name_scope(self.vit_output.name):
|
| 687 |
+
self.vit_output.build(None)
|
| 688 |
+
if getattr(self, "layernorm_before", None) is not None:
|
| 689 |
+
with tf.name_scope(self.layernorm_before.name):
|
| 690 |
+
self.layernorm_before.build([None, None, self.config.hidden_size])
|
| 691 |
+
if getattr(self, "layernorm_after", None) is not None:
|
| 692 |
+
with tf.name_scope(self.layernorm_after.name):
|
| 693 |
+
self.layernorm_after.build([None, None, self.config.hidden_size])
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE
|
| 697 |
+
class TFViTMAEEncoder(keras.layers.Layer):
|
| 698 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 699 |
+
super().__init__(**kwargs)
|
| 700 |
+
|
| 701 |
+
self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
|
| 702 |
+
|
| 703 |
+
def call(
|
| 704 |
+
self,
|
| 705 |
+
hidden_states: tf.Tensor,
|
| 706 |
+
head_mask: tf.Tensor,
|
| 707 |
+
output_attentions: bool,
|
| 708 |
+
output_hidden_states: bool,
|
| 709 |
+
return_dict: bool,
|
| 710 |
+
training: bool = False,
|
| 711 |
+
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
|
| 712 |
+
all_hidden_states = () if output_hidden_states else None
|
| 713 |
+
all_attentions = () if output_attentions else None
|
| 714 |
+
|
| 715 |
+
for i, layer_module in enumerate(self.layer):
|
| 716 |
+
if output_hidden_states:
|
| 717 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 718 |
+
|
| 719 |
+
layer_outputs = layer_module(
|
| 720 |
+
hidden_states=hidden_states,
|
| 721 |
+
head_mask=head_mask[i],
|
| 722 |
+
output_attentions=output_attentions,
|
| 723 |
+
training=training,
|
| 724 |
+
)
|
| 725 |
+
hidden_states = layer_outputs[0]
|
| 726 |
+
|
| 727 |
+
if output_attentions:
|
| 728 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 729 |
+
|
| 730 |
+
# Add last layer
|
| 731 |
+
if output_hidden_states:
|
| 732 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 733 |
+
|
| 734 |
+
if not return_dict:
|
| 735 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
|
| 736 |
+
|
| 737 |
+
return TFBaseModelOutput(
|
| 738 |
+
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
def build(self, input_shape=None):
|
| 742 |
+
if self.built:
|
| 743 |
+
return
|
| 744 |
+
self.built = True
|
| 745 |
+
if getattr(self, "layer", None) is not None:
|
| 746 |
+
for layer in self.layer:
|
| 747 |
+
with tf.name_scope(layer.name):
|
| 748 |
+
layer.build(None)
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
@keras_serializable
|
| 752 |
+
class TFViTMAEMainLayer(keras.layers.Layer):
|
| 753 |
+
config_class = ViTMAEConfig
|
| 754 |
+
|
| 755 |
+
def __init__(self, config: ViTMAEConfig, **kwargs):
|
| 756 |
+
super().__init__(**kwargs)
|
| 757 |
+
|
| 758 |
+
self.config = config
|
| 759 |
+
|
| 760 |
+
self.embeddings = TFViTMAEEmbeddings(config, name="embeddings")
|
| 761 |
+
self.encoder = TFViTMAEEncoder(config, name="encoder")
|
| 762 |
+
self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
|
| 763 |
+
|
| 764 |
+
def get_input_embeddings(self) -> keras.layers.Layer:
|
| 765 |
+
return self.embeddings.patch_embeddings
|
| 766 |
+
|
| 767 |
+
def _prune_heads(self, heads_to_prune):
|
| 768 |
+
"""
|
| 769 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
| 770 |
+
class PreTrainedModel
|
| 771 |
+
"""
|
| 772 |
+
raise NotImplementedError
|
| 773 |
+
|
| 774 |
+
@unpack_inputs
|
| 775 |
+
def call(
|
| 776 |
+
self,
|
| 777 |
+
pixel_values: TFModelInputType | None = None,
|
| 778 |
+
noise: Optional[tf.Tensor] = None,
|
| 779 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 780 |
+
output_attentions: Optional[bool] = None,
|
| 781 |
+
output_hidden_states: Optional[bool] = None,
|
| 782 |
+
return_dict: Optional[bool] = None,
|
| 783 |
+
training: bool = False,
|
| 784 |
+
interpolate_pos_encoding: bool = False,
|
| 785 |
+
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
| 786 |
+
embedding_output, mask, ids_restore = self.embeddings(
|
| 787 |
+
pixel_values=pixel_values,
|
| 788 |
+
training=training,
|
| 789 |
+
noise=noise,
|
| 790 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
+
# Prepare head mask if needed
|
| 794 |
+
# 1.0 in head_mask indicate we keep the head
|
| 795 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 796 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 797 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 798 |
+
if head_mask is not None:
|
| 799 |
+
raise NotImplementedError
|
| 800 |
+
else:
|
| 801 |
+
head_mask = [None] * self.config.num_hidden_layers
|
| 802 |
+
|
| 803 |
+
encoder_outputs = self.encoder(
|
| 804 |
+
embedding_output,
|
| 805 |
+
head_mask=head_mask,
|
| 806 |
+
output_attentions=output_attentions,
|
| 807 |
+
output_hidden_states=output_hidden_states,
|
| 808 |
+
return_dict=return_dict,
|
| 809 |
+
training=training,
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
sequence_output = encoder_outputs[0]
|
| 813 |
+
sequence_output = self.layernorm(inputs=sequence_output)
|
| 814 |
+
|
| 815 |
+
if not return_dict:
|
| 816 |
+
return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
|
| 817 |
+
|
| 818 |
+
return TFViTMAEModelOutput(
|
| 819 |
+
last_hidden_state=sequence_output,
|
| 820 |
+
mask=mask,
|
| 821 |
+
ids_restore=ids_restore,
|
| 822 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 823 |
+
attentions=encoder_outputs.attentions,
|
| 824 |
+
)
|
| 825 |
+
|
| 826 |
+
def build(self, input_shape=None):
|
| 827 |
+
if self.built:
|
| 828 |
+
return
|
| 829 |
+
self.built = True
|
| 830 |
+
if getattr(self, "embeddings", None) is not None:
|
| 831 |
+
with tf.name_scope(self.embeddings.name):
|
| 832 |
+
self.embeddings.build(None)
|
| 833 |
+
if getattr(self, "encoder", None) is not None:
|
| 834 |
+
with tf.name_scope(self.encoder.name):
|
| 835 |
+
self.encoder.build(None)
|
| 836 |
+
if getattr(self, "layernorm", None) is not None:
|
| 837 |
+
with tf.name_scope(self.layernorm.name):
|
| 838 |
+
self.layernorm.build([None, None, self.config.hidden_size])
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class TFViTMAEPreTrainedModel(TFPreTrainedModel):
|
| 842 |
+
"""
|
| 843 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 844 |
+
models.
|
| 845 |
+
"""
|
| 846 |
+
|
| 847 |
+
config_class = ViTMAEConfig
|
| 848 |
+
base_model_prefix = "vit"
|
| 849 |
+
main_input_name = "pixel_values"
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
VIT_MAE_START_DOCSTRING = r"""
|
| 853 |
+
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 854 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 855 |
+
etc.)
|
| 856 |
+
|
| 857 |
+
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
|
| 858 |
+
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
|
| 859 |
+
behavior.
|
| 860 |
+
|
| 861 |
+
<Tip>
|
| 862 |
+
|
| 863 |
+
TensorFlow models and layers in `transformers` accept two formats as input:
|
| 864 |
+
|
| 865 |
+
- having all inputs as keyword arguments (like PyTorch models), or
|
| 866 |
+
- having all inputs as a list, tuple or dict in the first positional argument.
|
| 867 |
+
|
| 868 |
+
The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
|
| 869 |
+
and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
|
| 870 |
+
pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
|
| 871 |
+
format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
|
| 872 |
+
the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
|
| 873 |
+
positional argument:
|
| 874 |
+
|
| 875 |
+
- a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
|
| 876 |
+
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
|
| 877 |
+
`model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
|
| 878 |
+
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
|
| 879 |
+
`model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
|
| 880 |
+
|
| 881 |
+
Note that when creating models and layers with
|
| 882 |
+
[subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
|
| 883 |
+
about any of this, as you can just pass inputs like you would to any other Python function!
|
| 884 |
+
|
| 885 |
+
</Tip>
|
| 886 |
+
|
| 887 |
+
Args:
|
| 888 |
+
config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
|
| 889 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 890 |
+
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
|
| 891 |
+
"""
|
| 892 |
+
|
| 893 |
+
VIT_MAE_INPUTS_DOCSTRING = r"""
|
| 894 |
+
Args:
|
| 895 |
+
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, with each example of shape `(batch_size, num_channels, height, width)`):
|
| 896 |
+
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
|
| 897 |
+
for details.
|
| 898 |
+
|
| 899 |
+
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 900 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 901 |
+
- 1 indicates the head is **not masked**,
|
| 902 |
+
- 0 indicates the head is **masked**.
|
| 903 |
+
|
| 904 |
+
output_attentions (`bool`, *optional*):
|
| 905 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 906 |
+
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
|
| 907 |
+
config will be used instead.
|
| 908 |
+
|
| 909 |
+
output_hidden_states (`bool`, *optional*):
|
| 910 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 911 |
+
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
|
| 912 |
+
used instead.
|
| 913 |
+
|
| 914 |
+
return_dict (`bool`, *optional*):
|
| 915 |
+
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
|
| 916 |
+
in eager mode; in graph mode the value will always be set to `True`.
|
| 917 |
+
|
| 918 |
+
training (`bool`, *optional*, defaults to `False`):
|
| 919 |
+
Whether or not to use the model in training mode (some modules like dropout modules have different
|
| 920 |
+
behaviors between training and evaluation).
|
| 921 |
+
|
| 922 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 923 |
+
Whether to interpolate the position encodings at the encoder and decoder.
|
| 924 |
+
"""
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
@add_start_docstrings(
|
| 928 |
+
"The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
|
| 929 |
+
VIT_MAE_START_DOCSTRING,
|
| 930 |
+
)
|
| 931 |
+
class TFViTMAEModel(TFViTMAEPreTrainedModel):
|
| 932 |
+
def __init__(self, config: ViTMAEConfig, *inputs, **kwargs):
|
| 933 |
+
super().__init__(config, *inputs, **kwargs)
|
| 934 |
+
|
| 935 |
+
self.vit = TFViTMAEMainLayer(config, name="vit")
|
| 936 |
+
|
| 937 |
+
def get_input_embeddings(self):
|
| 938 |
+
return self.vit.get_input_embeddings()
|
| 939 |
+
|
| 940 |
+
@unpack_inputs
|
| 941 |
+
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
|
| 942 |
+
@replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
|
| 943 |
+
def call(
|
| 944 |
+
self,
|
| 945 |
+
pixel_values: TFModelInputType | None = None,
|
| 946 |
+
noise: Optional[tf.Tensor] = None,
|
| 947 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 948 |
+
output_attentions: Optional[bool] = None,
|
| 949 |
+
output_hidden_states: Optional[bool] = None,
|
| 950 |
+
return_dict: Optional[bool] = None,
|
| 951 |
+
training: bool = False,
|
| 952 |
+
interpolate_pos_encoding: bool = False,
|
| 953 |
+
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
|
| 954 |
+
r"""
|
| 955 |
+
Returns:
|
| 956 |
+
|
| 957 |
+
Examples:
|
| 958 |
+
|
| 959 |
+
```python
|
| 960 |
+
>>> from transformers import AutoImageProcessor, TFViTMAEModel
|
| 961 |
+
>>> from PIL import Image
|
| 962 |
+
>>> import requests
|
| 963 |
+
|
| 964 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 965 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 966 |
+
|
| 967 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
|
| 968 |
+
>>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
|
| 969 |
+
|
| 970 |
+
>>> inputs = image_processor(images=image, return_tensors="tf")
|
| 971 |
+
>>> outputs = model(**inputs)
|
| 972 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 973 |
+
```"""
|
| 974 |
+
outputs = self.vit(
|
| 975 |
+
pixel_values=pixel_values,
|
| 976 |
+
noise=noise,
|
| 977 |
+
head_mask=head_mask,
|
| 978 |
+
output_attentions=output_attentions,
|
| 979 |
+
output_hidden_states=output_hidden_states,
|
| 980 |
+
return_dict=return_dict,
|
| 981 |
+
training=training,
|
| 982 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 983 |
+
)
|
| 984 |
+
|
| 985 |
+
return outputs
|
| 986 |
+
|
| 987 |
+
def build(self, input_shape=None):
|
| 988 |
+
if self.built:
|
| 989 |
+
return
|
| 990 |
+
self.built = True
|
| 991 |
+
if getattr(self, "vit", None) is not None:
|
| 992 |
+
with tf.name_scope(self.vit.name):
|
| 993 |
+
self.vit.build(None)
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
class TFViTMAEDecoder(keras.layers.Layer):
|
| 997 |
+
def __init__(self, config, num_patches, **kwargs):
|
| 998 |
+
super().__init__(**kwargs)
|
| 999 |
+
self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed")
|
| 1000 |
+
|
| 1001 |
+
decoder_config = deepcopy(config)
|
| 1002 |
+
decoder_config.hidden_size = config.decoder_hidden_size
|
| 1003 |
+
decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
|
| 1004 |
+
decoder_config.num_attention_heads = config.decoder_num_attention_heads
|
| 1005 |
+
decoder_config.intermediate_size = config.decoder_intermediate_size
|
| 1006 |
+
self.decoder_layers = [
|
| 1007 |
+
TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers)
|
| 1008 |
+
]
|
| 1009 |
+
|
| 1010 |
+
self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
|
| 1011 |
+
self.decoder_pred = keras.layers.Dense(
|
| 1012 |
+
config.patch_size**2 * config.num_channels,
|
| 1013 |
+
kernel_initializer=get_initializer(config.initializer_range),
|
| 1014 |
+
name="decoder_pred",
|
| 1015 |
+
) # encoder to decoder
|
| 1016 |
+
self.config = config
|
| 1017 |
+
self.num_patches = num_patches
|
| 1018 |
+
|
| 1019 |
+
def build(self, input_shape=None):
|
| 1020 |
+
self.mask_token = self.add_weight(
|
| 1021 |
+
shape=(1, 1, self.config.decoder_hidden_size),
|
| 1022 |
+
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
|
| 1023 |
+
trainable=True,
|
| 1024 |
+
name="mask_token",
|
| 1025 |
+
)
|
| 1026 |
+
self.decoder_pos_embed = self.add_weight(
|
| 1027 |
+
shape=(1, self.num_patches + 1, self.config.decoder_hidden_size),
|
| 1028 |
+
initializer="zeros",
|
| 1029 |
+
trainable=False,
|
| 1030 |
+
name="decoder_pos_embed",
|
| 1031 |
+
)
|
| 1032 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(
|
| 1033 |
+
self.decoder_pos_embed.shape[-1],
|
| 1034 |
+
int(self.num_patches**0.5),
|
| 1035 |
+
add_cls_token=True,
|
| 1036 |
+
)[None, ...]
|
| 1037 |
+
self.decoder_pos_embed.assign(decoder_pos_embed)
|
| 1038 |
+
|
| 1039 |
+
if self.built:
|
| 1040 |
+
return
|
| 1041 |
+
self.built = True
|
| 1042 |
+
if getattr(self, "decoder_embed", None) is not None:
|
| 1043 |
+
with tf.name_scope(self.decoder_embed.name):
|
| 1044 |
+
self.decoder_embed.build([None, None, self.config.hidden_size])
|
| 1045 |
+
if getattr(self, "decoder_norm", None) is not None:
|
| 1046 |
+
with tf.name_scope(self.decoder_norm.name):
|
| 1047 |
+
self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
|
| 1048 |
+
if getattr(self, "decoder_pred", None) is not None:
|
| 1049 |
+
with tf.name_scope(self.decoder_pred.name):
|
| 1050 |
+
self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
|
| 1051 |
+
if getattr(self, "decoder_layers", None) is not None:
|
| 1052 |
+
for layer in self.decoder_layers:
|
| 1053 |
+
with tf.name_scope(layer.name):
|
| 1054 |
+
layer.build(None)
|
| 1055 |
+
|
| 1056 |
+
def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
|
| 1057 |
+
"""
|
| 1058 |
+
This method is a modified version of the ViT-MAE interpolation function, adapted for the decoder, that
|
| 1059 |
+
allows interpolating the pre-trained decoder position encodings, to be able to use the model on higher
|
| 1060 |
+
resolution images.
|
| 1061 |
+
|
| 1062 |
+
Source:
|
| 1063 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
| 1064 |
+
"""
|
| 1065 |
+
|
| 1066 |
+
# [batch_size, num_patches + 1, hidden_size]
|
| 1067 |
+
_, num_positions, dim = shape_list(self.decoder_pos_embed)
|
| 1068 |
+
|
| 1069 |
+
# -1 removes the class dimension since we later append it without interpolation
|
| 1070 |
+
seq_len = shape_list(embeddings)[1] - 1
|
| 1071 |
+
num_positions = num_positions - 1
|
| 1072 |
+
|
| 1073 |
+
# Separation of class token and patch tokens
|
| 1074 |
+
class_pos_embed = self.decoder_pos_embed[:, :1, :]
|
| 1075 |
+
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
|
| 1076 |
+
|
| 1077 |
+
# interpolate the position embeddings
|
| 1078 |
+
patch_pos_embed = tf.image.resize(
|
| 1079 |
+
images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
|
| 1080 |
+
size=(1, seq_len),
|
| 1081 |
+
method="bicubic",
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
# [1, seq_len, hidden_size]
|
| 1085 |
+
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
|
| 1086 |
+
# Adding the class token back
|
| 1087 |
+
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
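The resize above works by treating the `(num_positions, dim)` embedding table as a one-pixel-high image so that `tf.image.resize` can stretch it along the sequence axis. A standalone sketch of the same shape manipulation on a toy table (the sizes are arbitrary):

```python
import tensorflow as tf

dim, num_positions, new_seq_len = 16, 196, 256
patch_pos_embed = tf.random.normal((1, num_positions, dim))

# view the table as a (batch=1, height=1, width=num_positions, channels=dim) image
as_image = tf.reshape(patch_pos_embed, (1, 1, num_positions, dim))
resized = tf.image.resize(as_image, size=(1, new_seq_len), method="bicubic")
patch_pos_embed = tf.reshape(resized, (1, new_seq_len, dim))
print(patch_pos_embed.shape)  # (1, 256, 16)
```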
|
| 1088 |
+
|
| 1089 |
+
def call(
|
| 1090 |
+
self,
|
| 1091 |
+
hidden_states,
|
| 1092 |
+
ids_restore,
|
| 1093 |
+
output_attentions=False,
|
| 1094 |
+
output_hidden_states=False,
|
| 1095 |
+
return_dict=True,
|
| 1096 |
+
interpolate_pos_encoding=False,
|
| 1097 |
+
):
|
| 1098 |
+
# embed tokens
|
| 1099 |
+
x = self.decoder_embed(hidden_states)
|
| 1100 |
+
# append mask tokens to sequence
|
| 1101 |
+
mask_tokens = tf.tile(
|
| 1102 |
+
self.mask_token,
|
| 1103 |
+
(shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
|
| 1104 |
+
)
|
| 1105 |
+
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
|
| 1106 |
+
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
|
| 1107 |
+
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
|
| 1108 |
+
if interpolate_pos_encoding:
|
| 1109 |
+
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
| 1110 |
+
else:
|
| 1111 |
+
decoder_pos_embed = self.decoder_pos_embed
|
| 1112 |
+
# add pos embed
|
| 1113 |
+
hidden_states = x + decoder_pos_embed
|
| 1114 |
+
# apply Transformer layers (blocks)
|
| 1115 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1116 |
+
all_self_attentions = () if output_attentions else None
|
| 1117 |
+
for i, layer_module in enumerate(self.decoder_layers):
|
| 1118 |
+
if output_hidden_states:
|
| 1119 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1120 |
+
|
| 1121 |
+
layer_outputs = layer_module(
|
| 1122 |
+
hidden_states,
|
| 1123 |
+
head_mask=None,
|
| 1124 |
+
output_attentions=output_attentions,
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
hidden_states = layer_outputs[0]
|
| 1128 |
+
|
| 1129 |
+
if output_attentions:
|
| 1130 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 1131 |
+
|
| 1132 |
+
if output_hidden_states:
|
| 1133 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1134 |
+
|
| 1135 |
+
hidden_states = self.decoder_norm(hidden_states)
|
| 1136 |
+
|
| 1137 |
+
# predictor projection
|
| 1138 |
+
logits = self.decoder_pred(hidden_states)
|
| 1139 |
+
|
| 1140 |
+
# remove cls token
|
| 1141 |
+
logits = logits[:, 1:, :]
|
| 1142 |
+
|
| 1143 |
+
if not return_dict:
|
| 1144 |
+
return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
|
| 1145 |
+
return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
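The concat/gather sequence above is the core MAE decoder step: the visible-patch embeddings are padded with copies of the mask token and then reordered back to the original patch order via `ids_restore`. A minimal sketch of that unshuffle on toy tensors (no CLS token, hypothetical sizes):

```python
import tensorflow as tf

batch, seq_len, kept, dim = 2, 6, 2, 4
visible = tf.random.normal((batch, kept, dim))   # encoder output for the visible patches
mask_token = tf.zeros((1, 1, dim))               # learned parameter in the real decoder
ids_restore = tf.constant([[3, 0, 5, 1, 4, 2],
                           [1, 4, 0, 2, 5, 3]])  # one permutation per sample

# pad with mask tokens up to the full length, then undo the shuffle
mask_tokens = tf.tile(mask_token, (batch, seq_len - kept, 1))
full = tf.concat([visible, mask_tokens], axis=1)                       # (2, 6, 4)
restored = tf.gather(full, indices=ids_restore, axis=1, batch_dims=1)  # original patch order
print(restored.shape)  # (2, 6, 4)
```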
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
@add_start_docstrings(
|
| 1149 |
+
"The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.",
|
| 1150 |
+
VIT_MAE_START_DOCSTRING,
|
| 1151 |
+
)
|
| 1152 |
+
class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
|
| 1153 |
+
def __init__(self, config):
|
| 1154 |
+
super().__init__(config)
|
| 1155 |
+
self.config = config
|
| 1156 |
+
|
| 1157 |
+
self.vit = TFViTMAEMainLayer(config, name="vit")
|
| 1158 |
+
self.decoder = TFViTMAEDecoder(
|
| 1159 |
+
config,
|
| 1160 |
+
num_patches=self.vit.embeddings.num_patches,
|
| 1161 |
+
name="decoder",
|
| 1162 |
+
)
|
| 1163 |
+
|
| 1164 |
+
def get_input_embeddings(self):
|
| 1165 |
+
return self.vit.get_input_embeddings()
|
| 1166 |
+
|
| 1167 |
+
def _prune_heads(self, heads_to_prune):
|
| 1168 |
+
raise NotImplementedError
|
| 1169 |
+
|
| 1170 |
+
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
|
| 1171 |
+
"""
|
| 1172 |
+
Args:
|
| 1173 |
+
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
|
| 1174 |
+
Pixel values.
|
| 1175 |
+
interpolate_pos_encoding (`bool`, default `False`):
|
| 1176 |
+
interpolation flag passed during the forward pass.
|
| 1177 |
+
|
| 1178 |
+
Returns:
|
| 1179 |
+
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
|
| 1180 |
+
Patchified pixel values.
|
| 1181 |
+
"""
|
| 1182 |
+
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
| 1183 |
+
# make sure channels are last
|
| 1184 |
+
if shape_list(pixel_values)[1] == num_channels:
|
| 1185 |
+
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
|
| 1186 |
+
|
| 1187 |
+
# sanity checks
|
| 1188 |
+
if not interpolate_pos_encoding:
|
| 1189 |
+
tf.debugging.assert_equal(
|
| 1190 |
+
shape_list(pixel_values)[1],
|
| 1191 |
+
shape_list(pixel_values)[2],
|
| 1192 |
+
message="Make sure the pixel values have a squared size",
|
| 1193 |
+
)
|
| 1194 |
+
tf.debugging.assert_equal(
|
| 1195 |
+
shape_list(pixel_values)[1] % patch_size,
|
| 1196 |
+
0,
|
| 1197 |
+
message="Make sure the pixel values have a size that is divisible by the patch size",
|
| 1198 |
+
)
|
| 1199 |
+
tf.debugging.assert_equal(
|
| 1200 |
+
shape_list(pixel_values)[3],
|
| 1201 |
+
num_channels,
|
| 1202 |
+
message=(
|
| 1203 |
+
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
|
| 1204 |
+
),
|
| 1205 |
+
)
|
| 1206 |
+
|
| 1207 |
+
# patchify
|
| 1208 |
+
batch_size = shape_list(pixel_values)[0]
|
| 1209 |
+
num_patches_h = shape_list(pixel_values)[1] // patch_size
|
| 1210 |
+
num_patches_w = shape_list(pixel_values)[2] // patch_size
|
| 1211 |
+
patchified_pixel_values = tf.reshape(
|
| 1212 |
+
pixel_values,
|
| 1213 |
+
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
|
| 1214 |
+
)
|
| 1215 |
+
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
|
| 1216 |
+
patchified_pixel_values = tf.reshape(
|
| 1217 |
+
patchified_pixel_values,
|
| 1218 |
+
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
|
| 1219 |
+
)
|
| 1220 |
+
return patchified_pixel_values
|
| 1221 |
+
|
| 1222 |
+
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
|
| 1223 |
+
"""
|
| 1224 |
+
Args:
|
| 1225 |
+
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
|
| 1226 |
+
Patchified pixel values.
|
| 1227 |
+
original_image_size (`Tuple[int, int]`, *optional*):
|
| 1228 |
+
Original image size.
|
| 1229 |
+
|
| 1230 |
+
Returns:
|
| 1231 |
+
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
|
| 1232 |
+
Pixel values.
|
| 1233 |
+
"""
|
| 1234 |
+
patch_size, num_channels = self.config.patch_size, self.config.num_channels
|
| 1235 |
+
original_image_size = (
|
| 1236 |
+
original_image_size
|
| 1237 |
+
if original_image_size is not None
|
| 1238 |
+
else (self.config.image_size, self.config.image_size)
|
| 1239 |
+
)
|
| 1240 |
+
original_height, original_width = original_image_size
|
| 1241 |
+
num_patches_h = original_height // patch_size
|
| 1242 |
+
num_patches_w = original_width // patch_size
|
| 1243 |
+
# sanity check
|
| 1244 |
+
tf.debugging.assert_equal(
|
| 1245 |
+
num_patches_h * num_patches_w,
|
| 1246 |
+
shape_list(patchified_pixel_values)[1],
|
| 1247 |
+
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
|
| 1248 |
+
)
|
| 1249 |
+
|
| 1250 |
+
# unpatchify
|
| 1251 |
+
batch_size = shape_list(patchified_pixel_values)[0]
|
| 1252 |
+
patchified_pixel_values = tf.reshape(
|
| 1253 |
+
patchified_pixel_values,
|
| 1254 |
+
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
|
| 1255 |
+
)
|
| 1256 |
+
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
|
| 1257 |
+
pixel_values = tf.reshape(
|
| 1258 |
+
patchified_pixel_values,
|
| 1259 |
+
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
|
| 1260 |
+
)
|
| 1261 |
+
return pixel_values
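`patchify` and `unpatchify` are inverses of each other for square images whose side is divisible by the patch size, so a round trip recovers the input (returned channels-last). A quick sanity-check sketch with a small, hypothetical configuration:

```python
import tensorflow as tf
from transformers import TFViTMAEForPreTraining, ViTMAEConfig

# small hypothetical config so the example stays light; only the shapes matter here
config = ViTMAEConfig(
    image_size=32, patch_size=8, num_channels=3,
    hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64,
    decoder_hidden_size=32, decoder_num_hidden_layers=1,
    decoder_num_attention_heads=2, decoder_intermediate_size=64,
)
model = TFViTMAEForPreTraining(config)

pixel_values = tf.random.uniform((2, 3, 32, 32))  # channels-first input
patches = model.patchify(pixel_values)            # (2, 16, 8 * 8 * 3)
images = model.unpatchify(patches)                # (2, 32, 32, 3), channels-last
print(patches.shape, images.shape)
```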
|
| 1262 |
+
|
| 1263 |
+
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
|
| 1264 |
+
"""
|
| 1265 |
+
Args:
|
| 1266 |
+
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
|
| 1267 |
+
Pixel values.
|
| 1268 |
+
pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
|
| 1269 |
+
Predicted pixel values.
|
| 1270 |
+
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
|
| 1271 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 1272 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 1273 |
+
interpolation flag passed during the forward pass.
|
| 1274 |
+
|
| 1275 |
+
Returns:
|
| 1276 |
+
`tf.Tensor`: Pixel reconstruction loss.
|
| 1277 |
+
"""
|
| 1278 |
+
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 1279 |
+
if self.config.norm_pix_loss:
|
| 1280 |
+
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
|
| 1281 |
+
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
|
| 1282 |
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
| 1283 |
+
|
| 1284 |
+
loss = (pred - target) ** 2
|
| 1285 |
+
loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch
|
| 1286 |
+
|
| 1287 |
+
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
|
| 1288 |
+
loss = tf.reshape(loss, (1,))
|
| 1289 |
+
return loss
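The reduction above is a per-patch mean-squared error (optionally on per-patch-normalized targets when `norm_pix_loss` is set) averaged over the masked patches only. A minimal restatement of that reduction on toy tensors:

```python
import tensorflow as tf

batch, num_patches, patch_dim = 2, 4, 6
target = tf.random.normal((batch, num_patches, patch_dim))  # patchified ground-truth pixels
pred = tf.random.normal((batch, num_patches, patch_dim))    # decoder predictions
mask = tf.constant([[1.0, 0.0, 1.0, 0.0],
                    [0.0, 1.0, 1.0, 1.0]])                  # 1 marks a masked (reconstructed) patch

per_patch_loss = tf.reduce_mean((pred - target) ** 2, axis=-1)     # (batch, num_patches)
loss = tf.reduce_sum(per_patch_loss * mask) / tf.reduce_sum(mask)  # mean over masked patches only
print(float(loss))
```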
|
| 1290 |
+
|
| 1291 |
+
@unpack_inputs
|
| 1292 |
+
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
|
| 1293 |
+
@replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
| 1294 |
+
def call(
|
| 1295 |
+
self,
|
| 1296 |
+
pixel_values: TFModelInputType | None = None,
|
| 1297 |
+
noise: Optional[tf.Tensor] = None,
|
| 1298 |
+
head_mask: np.ndarray | tf.Tensor | None = None,
|
| 1299 |
+
output_attentions: Optional[bool] = None,
|
| 1300 |
+
output_hidden_states: Optional[bool] = None,
|
| 1301 |
+
return_dict: Optional[bool] = None,
|
| 1302 |
+
training: bool = False,
|
| 1303 |
+
interpolate_pos_encoding: bool = False,
|
| 1304 |
+
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
|
| 1305 |
+
r"""
|
| 1306 |
+
Returns:
|
| 1307 |
+
|
| 1308 |
+
Examples:
|
| 1309 |
+
|
| 1310 |
+
```python
|
| 1311 |
+
>>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining
|
| 1312 |
+
>>> from PIL import Image
|
| 1313 |
+
>>> import requests
|
| 1314 |
+
|
| 1315 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 1316 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 1317 |
+
|
| 1318 |
+
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
|
| 1319 |
+
>>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
|
| 1320 |
+
|
| 1321 |
+
>>> inputs = image_processor(images=image, return_tensors="tf")
|
| 1322 |
+
>>> outputs = model(**inputs)
|
| 1323 |
+
>>> loss = outputs.loss
|
| 1324 |
+
>>> mask = outputs.mask
|
| 1325 |
+
>>> ids_restore = outputs.ids_restore
|
| 1326 |
+
```"""
|
| 1327 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1328 |
+
|
| 1329 |
+
outputs = self.vit(
|
| 1330 |
+
pixel_values=pixel_values,
|
| 1331 |
+
noise=noise,
|
| 1332 |
+
head_mask=head_mask,
|
| 1333 |
+
output_attentions=output_attentions,
|
| 1334 |
+
output_hidden_states=output_hidden_states,
|
| 1335 |
+
return_dict=return_dict,
|
| 1336 |
+
training=training,
|
| 1337 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
+
latent = outputs.last_hidden_state
|
| 1341 |
+
ids_restore = outputs.ids_restore
|
| 1342 |
+
mask = outputs.mask
|
| 1343 |
+
|
| 1344 |
+
# [batch_size, num_patches, patch_size**2*3]
|
| 1345 |
+
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 1346 |
+
logits = decoder_outputs.logits
|
| 1347 |
+
|
| 1348 |
+
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 1349 |
+
|
| 1350 |
+
if not return_dict:
|
| 1351 |
+
output = (logits, mask, ids_restore) + outputs[2:]
|
| 1352 |
+
return ((loss,) + output) if loss is not None else output
|
| 1353 |
+
|
| 1354 |
+
return TFViTMAEForPreTrainingOutput(
|
| 1355 |
+
loss=loss,
|
| 1356 |
+
logits=logits,
|
| 1357 |
+
mask=mask,
|
| 1358 |
+
ids_restore=ids_restore,
|
| 1359 |
+
hidden_states=outputs.hidden_states,
|
| 1360 |
+
attentions=outputs.attentions,
|
| 1361 |
+
)
|
| 1362 |
+
|
| 1363 |
+
def build(self, input_shape=None):
|
| 1364 |
+
if self.built:
|
| 1365 |
+
return
|
| 1366 |
+
self.built = True
|
| 1367 |
+
if getattr(self, "vit", None) is not None:
|
| 1368 |
+
with tf.name_scope(self.vit.name):
|
| 1369 |
+
self.vit.build(None)
|
| 1370 |
+
if getattr(self, "decoder", None) is not None:
|
| 1371 |
+
with tf.name_scope(self.decoder.name):
|
| 1372 |
+
self.decoder.build(None)
|
| 1373 |
+
|
| 1374 |
+
|
| 1375 |
+
__all__ = ["TFViTMAEForPreTraining", "TFViTMAEModel", "TFViTMAEPreTrainedModel"]
|
docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py
ADDED
|
@@ -0,0 +1,1163 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ViT MAE (masked autoencoder) model."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Callable, Optional, Set, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.utils.checkpoint
|
| 25 |
+
from torch import nn
|
| 26 |
+
|
| 27 |
+
from ...activations import ACT2FN
|
| 28 |
+
from ...modeling_outputs import BaseModelOutput
|
| 29 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 30 |
+
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 31 |
+
from ...utils import (
|
| 32 |
+
ModelOutput,
|
| 33 |
+
add_start_docstrings,
|
| 34 |
+
add_start_docstrings_to_model_forward,
|
| 35 |
+
logging,
|
| 36 |
+
replace_return_docstrings,
|
| 37 |
+
torch_int,
|
| 38 |
+
)
|
| 39 |
+
from .configuration_vit_mae import ViTMAEConfig
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
logger = logging.get_logger(__name__)
|
| 43 |
+
|
| 44 |
+
_CONFIG_FOR_DOC = "ViTMAEConfig"
|
| 45 |
+
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class ViTMAEModelOutput(ModelOutput):
|
| 50 |
+
"""
|
| 51 |
+
Class for ViTMAEModel's outputs, with potential hidden states and attentions.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 55 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 56 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 57 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 58 |
+
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 59 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 60 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 61 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 62 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 63 |
+
plus the initial embedding outputs.
|
| 64 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 65 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 66 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 67 |
+
the self-attention heads.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 71 |
+
mask: Optional[torch.LongTensor] = None
|
| 72 |
+
ids_restore: Optional[torch.LongTensor] = None
|
| 73 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 74 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class ViTMAEDecoderOutput(ModelOutput):
|
| 79 |
+
"""
|
| 80 |
+
Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 84 |
+
Pixel reconstruction logits.
|
| 85 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 86 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 87 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 88 |
+
plus the initial embedding outputs.
|
| 89 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 90 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 91 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 92 |
+
the self-attention heads.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
logits: Optional[torch.FloatTensor] = None
|
| 96 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 97 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
@dataclass
|
| 101 |
+
class ViTMAEForPreTrainingOutput(ModelOutput):
|
| 102 |
+
"""
|
| 103 |
+
Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
loss (`torch.FloatTensor` of shape `(1,)`):
|
| 107 |
+
Pixel reconstruction loss.
|
| 108 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
|
| 109 |
+
Pixel reconstruction logits.
|
| 110 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 111 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 112 |
+
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 113 |
+
Tensor containing the original index of the (shuffled) masked patches.
|
| 114 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 115 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
| 116 |
+
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
|
| 117 |
+
plus the initial embedding outputs.
|
| 118 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 119 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 120 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
| 121 |
+
the self-attention heads.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
loss: Optional[torch.FloatTensor] = None
|
| 125 |
+
logits: Optional[torch.FloatTensor] = None
|
| 126 |
+
mask: Optional[torch.LongTensor] = None
|
| 127 |
+
ids_restore: Optional[torch.LongTensor] = None
|
| 128 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 129 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
| 133 |
+
"""
|
| 134 |
+
Create 2D sin/cos positional embeddings.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
embed_dim (`int`):
|
| 138 |
+
Embedding dimension.
|
| 139 |
+
grid_size (`int`):
|
| 140 |
+
The grid height and width.
|
| 141 |
+
add_cls_token (`bool`, *optional*, defaults to `False`):
|
| 142 |
+
Whether or not to add a classification (CLS) token.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
(`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim)): the
|
| 146 |
+
position embeddings (with or without classification token)
|
| 147 |
+
"""
|
| 148 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 149 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 150 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 151 |
+
grid = np.stack(grid, axis=0)
|
| 152 |
+
|
| 153 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 154 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 155 |
+
if add_cls_token:
|
| 156 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 157 |
+
return pos_embed
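A short sketch of what this helper returns: one `embed_dim`-dimensional sin/cos vector per grid position, with an optional all-zeros row prepended for the CLS token (the import assumes this module's public helpers):

```python
import numpy as np
from transformers.models.vit_mae.modeling_vit_mae import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=14, add_cls_token=True)
print(pos_embed.shape)             # (1 + 14 * 14, 64)
print(np.abs(pos_embed[0]).max())  # 0.0 -- the prepended CLS row is all zeros
```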
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 161 |
+
if embed_dim % 2 != 0:
|
| 162 |
+
raise ValueError("embed_dim must be even")
|
| 163 |
+
|
| 164 |
+
# use half of dimensions to encode grid_h
|
| 165 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 166 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 167 |
+
|
| 168 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 169 |
+
return emb
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 173 |
+
"""
|
| 174 |
+
embed_dim: output dimension for each position
pos: a list of positions to be encoded, of size (M,)
out: (M, D)
|
| 175 |
+
"""
|
| 176 |
+
if embed_dim % 2 != 0:
|
| 177 |
+
raise ValueError("embed_dim must be even")
|
| 178 |
+
|
| 179 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 180 |
+
omega /= embed_dim / 2.0
|
| 181 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 182 |
+
|
| 183 |
+
pos = pos.reshape(-1) # (M,)
|
| 184 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 185 |
+
|
| 186 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 187 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 188 |
+
|
| 189 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 190 |
+
return emb
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ViTMAEEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.patch_embeddings = ViTMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
        )
        self.patch_size = config.patch_size
        self.config = config

    def initialize_weights(self):
        # initialize (and freeze) position embeddings by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
        )
        self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embeddings.projection.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def random_masking(self, sequence, noise=None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
            noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = sequence.shape
        len_keep = int(seq_length * (1 - self.config.mask_ratio))

        if noise is None:
            noise = torch.rand(batch_size, seq_length, device=sequence.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([batch_size, seq_length], device=sequence.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_unmasked, mask, ids_restore

    def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        if interpolate_pos_encoding:
            position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_embeddings = self.position_embeddings

        # add position embeddings w/o cls token
        embeddings = embeddings + position_embeddings[:, 1:, :]

        # masking: length -> length * (1 - config.mask_ratio)
        embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

        # append cls token
        cls_token = self.cls_token + position_embeddings[:, :1, :]
        cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        return embeddings, mask, ids_restore

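
# Illustrative sketch (not part of the original file): what ViTMAEEmbeddings.forward returns.
# Assumes a plain `ViTMAEConfig()` with the default 224x224 image, 16x16 patches and mask_ratio=0.75,
# so 196 patches of which 49 are kept (plus the CLS token).
def _demo_vit_mae_embeddings_masking():
    config = ViTMAEConfig()
    embeddings = ViTMAEEmbeddings(config)
    embeddings.initialize_weights()
    pixel_values = torch.randn(2, 3, 224, 224)
    sequence, mask, ids_restore = embeddings(pixel_values)
    assert sequence.shape == (2, 1 + 49, config.hidden_size)  # CLS token + kept patches
    assert mask.shape == (2, 196) and ids_restore.shape == (2, 196)
    assert int(mask.sum()) == 2 * 147  # 75% of the 196 patches are marked as removed
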
class ViTMAEPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x

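
# Illustrative sketch (not part of the original file): the Conv2d projection above is equivalent to
# slicing the image into non-overlapping patches and linearly embedding each one.
def _demo_patch_embeddings_shape():
    config = ViTMAEConfig()  # 224x224 images, 16x16 patches, hidden_size 768 by default
    patch_embed = ViTMAEPatchEmbeddings(config)
    pixel_values = torch.randn(1, 3, 224, 224)
    patches = patch_embed(pixel_values)
    # (224 / 16) ** 2 = 196 patch tokens, each projected to hidden_size
    assert patches.shape == (1, 196, config.hidden_size)
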
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

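
# Illustrative sketch (not part of the original file): the eager path above is plain scaled
# dot-product attention over (batch, heads, seq, head_dim) tensors.
def _demo_eager_attention_shapes():
    batch, heads, seq, head_dim = 2, 12, 50, 64
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)
    module = nn.Identity()  # any nn.Module works here; only its `.training` flag is read for dropout
    out, weights = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=head_dim**-0.5)
    assert weights.shape == (batch, heads, seq, seq)
    assert out.shape == (batch, seq, heads, head_dim)  # note the transpose(1, 2) before returning
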
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE
class ViTMAESelfAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
class ViTMAESelfOutput(nn.Module):
    """
    The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
class ViTMAEAttention(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.attention = ViTMAESelfAttention(config)
        self.output = ViTMAESelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

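
# Illustrative sketch (not part of the original file): pruning two heads from a ViTMAEAttention
# block shrinks the query/key/value projections while the output projection keeps the hidden size.
def _demo_prune_attention_heads():
    config = ViTMAEConfig()  # 12 heads of size 64 by default
    attention = ViTMAEAttention(config)
    attention.prune_heads({0, 5})
    assert attention.attention.num_attention_heads == 10
    assert attention.attention.all_head_size == 10 * 64
    hidden_states = torch.randn(1, 50, config.hidden_size)
    (attention_output,) = attention(hidden_states)
    assert attention_output.shape == (1, 50, config.hidden_size)
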
# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE
class ViTMAEIntermediate(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE
class ViTMAEOutput(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
class ViTMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTMAEAttention(config)
        self.intermediate = ViTMAEIntermediate(config)
        self.output = ViTMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViTMAE, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViTMAE, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
class ViTMAEEncoder(nn.Module):
    def __init__(self, config: ViTMAEConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViTMAEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTMAEConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTMAEEmbeddings):
            module.initialize_weights()
        elif isinstance(module, ViTMAEDecoder):
            module.mask_token.data.zero_()
            module.decoder_pos_embed.data.zero_()

VIT_MAE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIT_MAE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, default `False`):
            Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher
            resolution images.
"""


@add_start_docstrings(
    "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_MAE_START_DOCSTRING,
)
class ViTMAEModel(ViTMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTMAEEmbeddings(config)
        self.encoder = ViTMAEEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        noise: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, ViTMAEModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ViTMAEModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
        >>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output, mask, ids_restore = self.embeddings(
            pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            return (sequence_output, mask, ids_restore) + encoder_outputs[1:]

        return ViTMAEModelOutput(
            last_hidden_state=sequence_output,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

class ViTMAEDecoder(nn.Module):
    def __init__(self, config, num_patches):
        super().__init__()
        self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
        )  # fixed sin-cos embedding

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
        self.decoder_pred = nn.Linear(
            config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
        )  # encoder to decoder
        self.gradient_checkpointing = False
        self.config = config
        self.initialize_weights(num_patches)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        This method is a modified version of the interpolation function for ViT-mae model at the decoder, that
        allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
        resolution images.

        Adapted from:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        # -1 removes the class dimension since we later append it without interpolation
        embeddings_positions = embeddings.shape[1] - 1

        # Separation of class token and patch tokens
        class_pos_embed = self.decoder_pos_embed[:, :1]
        patch_pos_embed = self.decoder_pos_embed[:, 1:]

        # To retain the final 3d tensor with the required dimensions
        dim = self.decoder_pos_embed.shape[-1]

        # Increasing a dimension to enable bicubic interpolation
        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)

        # permute to bring the dimension to be interpolated, to the last
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        # Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
        # we keep the second last dimension constant
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(patch_pos_embed.shape[-2], embeddings_positions),
            mode="bicubic",
            align_corners=False,
        )

        # Converting back to the original shape
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # Adding the class token back
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def initialize_weights(self, num_patches):
        # initialize (and freeze) position embeddings by sin-cos embedding
        decoder_pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)

    def forward(
        self,
        hidden_states,
        ids_restore,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        interpolate_pos_encoding: bool = False,
    ):
        # embed tokens
        x = self.decoder_embed(hidden_states)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        # unshuffle
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
        # add pos embed
        if interpolate_pos_encoding:
            decoder_pos_embed = self.interpolate_pos_encoding(x)
        else:
            decoder_pos_embed = self.decoder_pos_embed
        hidden_states = x + decoder_pos_embed

        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.decoder_norm(hidden_states)

        # predictor projection
        logits = self.decoder_pred(hidden_states)

        # remove cls token
        logits = logits[:, 1:, :]

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return ViTMAEDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

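
# Illustrative sketch (not part of the original file): how the decoder restores the full sequence.
# The encoder only sees the kept patches, so mask tokens are appended and `ids_restore` (produced by
# random_masking) is used with torch.gather to put every token back at its original position.
def _demo_decoder_unshuffle():
    config = ViTMAEConfig()
    model = ViTMAEModel(config)
    decoder = ViTMAEDecoder(config, num_patches=model.embeddings.num_patches)
    outputs = model(torch.randn(1, 3, 224, 224))
    decoder_outputs = decoder(outputs.last_hidden_state, outputs.ids_restore)
    # one prediction of patch_size**2 * num_channels pixels per patch, CLS token already removed
    assert decoder_outputs.logits.shape == (1, 196, config.patch_size**2 * config.num_channels)
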
@add_start_docstrings(
    """The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>

    """,
    VIT_MAE_START_DOCSTRING,
)
class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.vit = ViTMAEModel(config)
        self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.vit.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.vit.encoder.layer[layer].attention.prune_heads(heads)

    def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            interpolate_pos_encoding (`bool`, *optional*, default `False`):
                interpolation flag passed during the forward pass.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        # sanity checks
        if not interpolate_pos_encoding and (
            pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
        ):
            raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
        if pixel_values.shape[1] != num_channels:
            raise ValueError(
                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
            )

        # patchify
        batch_size = pixel_values.shape[0]
        num_patches_h = pixel_values.shape[2] // patch_size
        num_patches_w = pixel_values.shape[3] // patch_size
        patchified_pixel_values = pixel_values.reshape(
            batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
        )
        patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
        )
        return patchified_pixel_values

    def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
        """
        Args:
            patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Patchified pixel values.
            original_image_size (`Tuple[int, int]`, *optional*):
                Original image size.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                Pixel values.
        """
        patch_size, num_channels = self.config.patch_size, self.config.num_channels
        original_image_size = (
            original_image_size
            if original_image_size is not None
            else (self.config.image_size, self.config.image_size)
        )
        original_height, original_width = original_image_size
        num_patches_h = original_height // patch_size
        num_patches_w = original_width // patch_size
        # sanity check
        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
            raise ValueError(
                f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
            )

        # unpatchify
        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_h,
            num_patches_w,
            patch_size,
            patch_size,
            num_channels,
        )
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_h * patch_size,
            num_patches_w * patch_size,
        )
        return pixel_values

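    # Illustrative note (not part of the original file): patchify and unpatchify are exact inverses
    # for images whose height and width are divisible by the patch size, e.g. (assuming a model with
    # the default 16x16 patches):
    #
    #   pixel_values = torch.randn(2, 3, 224, 224)
    #   patches = model.patchify(pixel_values)   # (2, 196, 16 * 16 * 3)
    #   restored = model.unpatchify(patches)     # (2, 3, 224, 224)
    #   assert torch.allclose(restored, pixel_values)
    #
    # `forward_loss` below relies on this layout: the decoder predicts one flattened patch per token.
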
    def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values.
            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                Predicted pixel values.
            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Tensor indicating which patches are masked (1) and which are not (0).
            interpolate_pos_encoding (`bool`, *optional*, default `False`):
                interpolation flag passed during the forward pass.

        Returns:
            `torch.FloatTensor`: Pixel reconstruction loss.
        """
        target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        if self.config.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

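    # Illustrative note (not part of the original file): with `norm_pix_loss` enabled each target
    # patch is standardized with its own mean and variance before the squared error is taken, and
    # the per-patch errors are averaged only over masked patches:
    #
    #   loss = sum_over_patches(per_patch_mse * mask) / mask.sum()
    #
    # so visible (unmasked) patches never contribute to the pre-training objective.
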
    @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        noise: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ViTMAEForPreTraining
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
        >>> model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> loss = outputs.loss
        >>> mask = outputs.mask
        >>> ids_restore = outputs.ids_restore
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            noise=noise,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        latent = outputs.last_hidden_state
        ids_restore = outputs.ids_restore
        mask = outputs.mask

        decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
        logits = decoder_outputs.logits  # shape (batch_size, num_patches, patch_size*patch_size*num_channels)

        loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)

        if not return_dict:
            output = (logits, mask, ids_restore) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ViTMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            mask=mask,
            ids_restore=ids_restore,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["ViTMAEForPreTraining", "ViTMAELayer", "ViTMAEModel", "ViTMAEPreTrainedModel"]
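
# Illustrative sketch (not part of the original file): turning the pre-training head's logits back
# into an image. `unpatchify` folds the per-patch predictions into pixel space; the mask can then
# be used to paste the reconstructed patches over the visible ones.
def _demo_reconstruct_image():
    config = ViTMAEConfig()
    model = ViTMAEForPreTraining(config)
    pixel_values = torch.randn(1, 3, 224, 224)
    outputs = model(pixel_values)
    reconstruction = model.unpatchify(outputs.logits)
    assert reconstruction.shape == pixel_values.shape
    # mask is 1 for removed patches; expand it to pixel space to blend original and reconstruction
    mask = outputs.mask.unsqueeze(-1).repeat(1, 1, config.patch_size**2 * config.num_channels)
    blended = model.unpatchify(model.patchify(pixel_values) * (1 - mask) + outputs.logits * mask)
    assert blended.shape == pixel_values.shape
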
docs/transformers/build/lib/transformers/models/vit_msn/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vit_msn import *
    from .modeling_vit_msn import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py
ADDED
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT MSN model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class ViTMSNConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate a ViT
    MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the ViT
    [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.

    Example:

    ```python
    >>> from transformers import ViTMSNModel, ViTMSNConfig

    >>> # Initializing a ViT MSN vit-msn-base style configuration
    >>> configuration = ViTMSNConfig()

    >>> # Initializing a model from the vit-msn-base style configuration
    >>> model = ViTMSNModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vit_msn"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-06,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias


__all__ = ["ViTMSNConfig"]
docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py
ADDED
@@ -0,0 +1,245 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""

import argparse
import json

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import ViTImageProcessor, ViTMSNConfig, ViTMSNModel
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


torch.set_grad_enabled(False)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")
        )
        rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("module.cls_token", "vit.embeddings.cls_token"),
            ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("module.pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm + pooler
        rename_keys.extend(
            [
                ("module.norm.weight", "layernorm.weight"),
                ("module.norm.bias", "layernorm.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys

# we split up the matrix of each encoder layer into queries, keys and values
|
| 86 |
+
def read_in_q_k_v(state_dict, config, base_model=False):
|
| 87 |
+
for i in range(config.num_hidden_layers):
|
| 88 |
+
if base_model:
|
| 89 |
+
prefix = ""
|
| 90 |
+
else:
|
| 91 |
+
prefix = "vit."
|
| 92 |
+
# read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
|
| 93 |
+
in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight")
|
| 94 |
+
in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias")
|
| 95 |
+
# next, add query, keys and values (in that order) to the state dict
|
| 96 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
|
| 97 |
+
: config.hidden_size, :
|
| 98 |
+
]
|
| 99 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
|
| 100 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
|
| 101 |
+
config.hidden_size : config.hidden_size * 2, :
|
| 102 |
+
]
|
| 103 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
|
| 104 |
+
config.hidden_size : config.hidden_size * 2
|
| 105 |
+
]
|
| 106 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
|
| 107 |
+
-config.hidden_size :, :
|
| 108 |
+
]
|
| 109 |
+
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def remove_classification_head_(state_dict):
|
| 113 |
+
ignore_keys = ["head.weight", "head.bias"]
|
| 114 |
+
for k in ignore_keys:
|
| 115 |
+
state_dict.pop(k, None)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def remove_projection_head(state_dict):
|
| 119 |
+
# projection head is used in the self-supervised pre-training in MSN,
|
| 120 |
+
# for downstream task it's not needed.
|
| 121 |
+
ignore_keys = [
|
| 122 |
+
"module.fc.fc1.weight",
|
| 123 |
+
"module.fc.fc1.bias",
|
| 124 |
+
"module.fc.bn1.weight",
|
| 125 |
+
"module.fc.bn1.bias",
|
| 126 |
+
"module.fc.bn1.running_mean",
|
| 127 |
+
"module.fc.bn1.running_var",
|
| 128 |
+
"module.fc.bn1.num_batches_tracked",
|
| 129 |
+
"module.fc.fc2.weight",
|
| 130 |
+
"module.fc.fc2.bias",
|
| 131 |
+
"module.fc.bn2.weight",
|
| 132 |
+
"module.fc.bn2.bias",
|
| 133 |
+
"module.fc.bn2.running_mean",
|
| 134 |
+
"module.fc.bn2.running_var",
|
| 135 |
+
"module.fc.bn2.num_batches_tracked",
|
| 136 |
+
"module.fc.fc3.weight",
|
| 137 |
+
"module.fc.fc3.bias",
|
| 138 |
+
]
|
| 139 |
+
for k in ignore_keys:
|
| 140 |
+
state_dict.pop(k, None)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def rename_key(dct, old, new):
|
| 144 |
+
val = dct.pop(old)
|
| 145 |
+
dct[new] = val
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
|
| 149 |
+
config = ViTMSNConfig()
|
| 150 |
+
config.num_labels = 1000
|
| 151 |
+
|
| 152 |
+
repo_id = "datasets/huggingface/label-files"
|
| 153 |
+
filename = "imagenet-1k-id2label.json"
|
| 154 |
+
id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
|
| 155 |
+
id2label = {int(k): v for k, v in id2label.items()}
|
| 156 |
+
config.id2label = id2label
|
| 157 |
+
config.label2id = {v: k for k, v in id2label.items()}
|
| 158 |
+
|
| 159 |
+
if "s16" in checkpoint_url:
|
| 160 |
+
config.hidden_size = 384
|
| 161 |
+
config.intermediate_size = 1536
|
| 162 |
+
config.num_attention_heads = 6
|
| 163 |
+
elif "l16" in checkpoint_url:
|
| 164 |
+
config.hidden_size = 1024
|
| 165 |
+
config.intermediate_size = 4096
|
| 166 |
+
config.num_hidden_layers = 24
|
| 167 |
+
config.num_attention_heads = 16
|
| 168 |
+
config.hidden_dropout_prob = 0.1
|
| 169 |
+
elif "b4" in checkpoint_url:
|
| 170 |
+
config.patch_size = 4
|
| 171 |
+
elif "l7" in checkpoint_url:
|
| 172 |
+
config.patch_size = 7
|
| 173 |
+
config.hidden_size = 1024
|
| 174 |
+
config.intermediate_size = 4096
|
| 175 |
+
config.num_hidden_layers = 24
|
| 176 |
+
config.num_attention_heads = 16
|
| 177 |
+
config.hidden_dropout_prob = 0.1
|
| 178 |
+
|
| 179 |
+
model = ViTMSNModel(config)
|
| 180 |
+
|
| 181 |
+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"]
|
| 182 |
+
|
| 183 |
+
image_processor = ViTImageProcessor(size=config.image_size)
|
| 184 |
+
|
| 185 |
+
remove_projection_head(state_dict)
|
| 186 |
+
rename_keys = create_rename_keys(config, base_model=True)
|
| 187 |
+
|
| 188 |
+
for src, dest in rename_keys:
|
| 189 |
+
rename_key(state_dict, src, dest)
|
| 190 |
+
read_in_q_k_v(state_dict, config, base_model=True)
|
| 191 |
+
|
| 192 |
+
model.load_state_dict(state_dict)
|
| 193 |
+
model.eval()
|
| 194 |
+
|
| 195 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 196 |
+
|
| 197 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
| 198 |
+
image_processor = ViTImageProcessor(
|
| 199 |
+
size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD
|
| 200 |
+
)
|
| 201 |
+
inputs = image_processor(images=image, return_tensors="pt")
|
| 202 |
+
|
| 203 |
+
# forward pass
|
| 204 |
+
torch.manual_seed(2)
|
| 205 |
+
outputs = model(**inputs)
|
| 206 |
+
last_hidden_state = outputs.last_hidden_state
|
| 207 |
+
|
| 208 |
+
# The following Colab Notebook was used to generate these outputs:
|
| 209 |
+
# https://colab.research.google.com/gist/sayakpaul/3672419a04f5997827503fd84079bdd1/scratchpad.ipynb
|
| 210 |
+
if "s16" in checkpoint_url:
|
| 211 |
+
expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]])
|
| 212 |
+
elif "b16" in checkpoint_url:
|
| 213 |
+
expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]])
|
| 214 |
+
elif "l16" in checkpoint_url:
|
| 215 |
+
expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]])
|
| 216 |
+
elif "b4" in checkpoint_url:
|
| 217 |
+
expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]])
|
| 218 |
+
else:
|
| 219 |
+
expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]])
|
| 220 |
+
|
| 221 |
+
# verify logits
|
| 222 |
+
assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4)
|
| 223 |
+
|
| 224 |
+
print(f"Saving model to {pytorch_dump_folder_path}")
|
| 225 |
+
model.save_pretrained(pytorch_dump_folder_path)
|
| 226 |
+
|
| 227 |
+
print(f"Saving image processor to {pytorch_dump_folder_path}")
|
| 228 |
+
image_processor.save_pretrained(pytorch_dump_folder_path)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
if __name__ == "__main__":
|
| 232 |
+
parser = argparse.ArgumentParser()
|
| 233 |
+
# Required parameters
|
| 234 |
+
parser.add_argument(
|
| 235 |
+
"--checkpoint_url",
|
| 236 |
+
default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
|
| 237 |
+
type=str,
|
| 238 |
+
help="URL of the checkpoint you'd like to convert.",
|
| 239 |
+
)
|
| 240 |
+
parser.add_argument(
|
| 241 |
+
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
args = parser.parse_args()
|
| 245 |
+
convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
|
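As a follow-up to the conversion script above, here is a minimal, hedged sketch of how the dumped folder could be reloaded and re-checked against the same COCO test image. The `dump_folder` path is an assumption (whatever was passed as `--pytorch_dump_folder_path`), not something the script fixes.

```python
# Hypothetical sanity check after running the conversion script above.
import requests
import torch
from PIL import Image

from transformers import ViTImageProcessor, ViTMSNModel

dump_folder = "./vit-msn-small-converted"  # assumed output folder of the script

model = ViTMSNModel.from_pretrained(dump_folder)
image_processor = ViTImageProcessor.from_pretrained(dump_folder)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    last_hidden_state = model(**inputs).last_hidden_state

# Should reproduce the slice verified by the script's assert for this checkpoint.
print(last_hidden_state[:, 0, :3])
```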
docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py
ADDED
@@ -0,0 +1,741 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViT MSN (masked siamese network) model."""

import collections.abc
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_vit_msn import ViTMSNConfig


logger = logging.get_logger(__name__)


_CONFIG_FOR_DOC = "ViTMSNConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-msn-small"


class ViTMSNEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTMSNPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN
class ViTMSNPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN
class ViTMSNSelfAttention(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
class ViTMSNSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN
class ViTMSNAttention(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.attention = ViTMSNSelfAttention(config)
        self.output = ViTMSNSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
class ViTMSNIntermediate(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN
class ViTMSNOutput(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
class ViTMSNLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTMSNAttention(config)
        self.intermediate = ViTMSNIntermediate(config)
        self.output = ViTMSNOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViTMSN, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViTMSN, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN
class ViTMSNEncoder(nn.Module):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViTMSNPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTMSNConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
    # when creating pre-training scripts.
    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTMSNEmbeddings):
            module.cls_token.data.zero_()
            module.position_embeddings.data.zero_()
            if module.mask_token is not None:
                module.mask_token.data.zero_()


VIT_MSN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIT_MSN_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.",
    VIT_MSN_START_DOCSTRING,
)
class ViTMSNModel(ViTMSNPreTrainedModel):
    def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTMSNEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTMSNPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ViTMSNModel
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
        >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


# Caution: We don't have the weights for the classification head yet. This class
# is here for the users that are interested to fine-tune the base model (ViTMSNModel).
@add_start_docstrings(
    """
    ViTMSN Model with an image classification head on top e.g. for ImageNet.
    """,
    VIT_MSN_START_DOCSTRING,
)
class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
    def __init__(self, config: ViTMSNConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTMSNModel(config)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ViTMSNForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(2)  # doctest: +IGNORE_RESULT

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
        >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])
        tusker
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["ViTMSNModel", "ViTMSNForImageClassification", "ViTMSNPreTrainedModel"]
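A hedged usage sketch for the `interpolate_pos_encoding` path implemented above: it runs the public `facebook/vit-msn-small` checkpoint on a resolution other than the pre-training size by bicubically interpolating the position embeddings. The 384x384 input is an arbitrary illustration, not a value taken from this file.

```python
import torch

from transformers import ViTMSNModel

model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")

# Dummy batch at a non-default resolution; with interpolate_pos_encoding=True the
# position embeddings are resized on the fly instead of raising a size-mismatch error.
pixel_values = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)

# 1 CLS token + (384 // 16) ** 2 = 577 tokens along the sequence dimension
print(outputs.last_hidden_state.shape)
```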
docs/transformers/build/lib/transformers/models/vitdet/__init__.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vitdet import *
    from .modeling_vitdet import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
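A small sketch of what the `_LazyModule` indirection above means for callers: the package looks like a regular module, and its submodule attributes are only resolved on first access. This mirrors the lazy-import pattern used across the library rather than anything specific to VitDet.

```python
# From the user's side the lazy module behaves like a normal one; configuration_vitdet
# is only materialized when the attribute is actually accessed.
from transformers.models.vitdet import VitDetConfig

config = VitDetConfig()
print(type(config).__name__)  # VitDetConfig
```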
docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py
ADDED
@@ -0,0 +1,156 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""VitDet model configuration"""
|
| 16 |
+
|
| 17 |
+
from ...configuration_utils import PretrainedConfig
|
| 18 |
+
from ...utils import logging
|
| 19 |
+
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logger = logging.get_logger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VitDetConfig(BackboneConfigMixin, PretrainedConfig):
|
| 26 |
+
r"""
|
| 27 |
+
This is the configuration class to store the configuration of a [`VitDetModel`]. It is used to instantiate an
|
| 28 |
+
VitDet model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 29 |
+
with the defaults will yield a similar configuration to that of the VitDet
|
| 30 |
+
[google/vitdet-base-patch16-224](https://huggingface.co/google/vitdet-base-patch16-224) architecture.
|
| 31 |
+
|
| 32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 33 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 37 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 38 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 39 |
+
Number of hidden layers in the Transformer encoder.
|
| 40 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 41 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 42 |
+
mlp_ratio (`int`, *optional*, defaults to 4):
|
| 43 |
+
Ratio of mlp hidden dim to embedding dim.
|
| 44 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 45 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 46 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 47 |
+
dropout_prob (`float`, *optional*, defaults to 0.0):
|
| 48 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 49 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 50 |
+
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        pretrain_image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image during pretraining.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
        window_block_indices (`List[int]`, *optional*, defaults to `[]`):
            List of indices of blocks that should have window attention instead of regular global self-attention.
        residual_block_indices (`List[int]`, *optional*, defaults to `[]`):
            List of indices of blocks that should have an extra residual block after the MLP.
        use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to add absolute position embeddings to the patch embeddings.
        use_relative_position_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to add relative position embeddings to the attention maps.
        window_size (`int`, *optional*, defaults to 0):
            The size of the attention window.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:

    ```python
    >>> from transformers import VitDetConfig, VitDetModel

    >>> # Initializing a VitDet configuration
    >>> configuration = VitDetConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = VitDetModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vitdet"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mlp_ratio=4,
        hidden_act="gelu",
        dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        image_size=224,
        pretrain_image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        drop_path_rate=0.0,
        window_block_indices=[],
        residual_block_indices=[],
        use_absolute_position_embeddings=True,
        use_relative_position_embeddings=False,
        window_size=0,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.hidden_act = hidden_act
        self.dropout_prob = dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.pretrain_image_size = pretrain_image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.drop_path_rate = drop_path_rate
        self.window_block_indices = window_block_indices
        self.residual_block_indices = residual_block_indices
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.use_relative_position_embeddings = use_relative_position_embeddings
        self.window_size = window_size

        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )


__all__ = ["VitDetConfig"]
docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py
ADDED
# coding=utf-8
# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViTDet backbone."""

import collections.abc
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_vitdet import VitDetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "VitDetConfig"


class VitDetEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) to be consumed by a Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.pretrain_image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        if config.use_absolute_position_embeddings:
            # Initialize absolute positional embedding with pretrain image size.
            num_positions = num_patches + 1
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size))
        else:
            self.position_embeddings = None

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width):
        """
        Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
        original embeddings.

        Args:
            abs_pos_embeddings (`torch.Tensor`):
                Absolute positional embeddings with (1, num_position, num_channels).
            has_cls_token (`bool`):
                If true, has 1 embedding in abs_pos_embeddings for cls token.
            height (`int`):
                Height of input image tokens.
            width (`int`):
                Width of input image tokens.

        Returns:
            Absolute positional embeddings after processing with shape (1, height, width, num_channels)
        """
        if has_cls_token:
            abs_pos_embeddings = abs_pos_embeddings[:, 1:]
        num_position = abs_pos_embeddings.shape[1]
        size = int(math.sqrt(num_position))  # This is a constant and can be recorded as such in the ONNX export.
        if size * size != num_position:
            raise ValueError("Absolute position embeddings must be a square number.")

        if torch.jit.is_tracing() or (size != height or size != width):
            # nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
            new_abs_pos_embeddings = nn.functional.interpolate(
                abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
                size=(height, width),
                mode="bicubic",
                align_corners=False,
            )

            return new_abs_pos_embeddings.permute(0, 2, 3, 1)
        else:
            return abs_pos_embeddings.reshape(1, height, width, -1)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        embeddings = self.projection(pixel_values)

        if self.position_embeddings is not None:
            # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
            embeddings = embeddings.permute(0, 2, 3, 1)
            # add position embeddings
            embeddings = embeddings + self.get_absolute_positions(
                self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2]
            )
            # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
            embeddings = embeddings.permute(0, 3, 1, 2)

        return embeddings


@torch.jit.script_if_tracing  # nn.functional.interpolate's `size` needs to be dynamic.
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of query and key sizes.

    Args:
        q_size (`int`):
            Size of query q.
        k_size (`int`):
            Size of key k.
        rel_pos (`torch.Tensor`):
            Relative position embeddings (num_embeddings, num_channels).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel position embeddings.
        rel_pos_resized = nn.functional.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_relative_positions(attn, queries, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings as introduced in
    [MViT2](https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py).

    Args:
        attn (`torch.Tensor`):
            Attention map.
        queries (`torch.Tensor`):
            Query q in the attention layer with shape (batch_size, queries_height * queries_width, num_channels).
        rel_pos_h (`torch.Tensor`):
            Relative position embeddings (Lh, num_channels) for height axis.
        rel_pos_w (`torch.Tensor`):
            Relative position embeddings (Lw, num_channels) for width axis.
        q_size (`Tuple[int]`):
            Spatial sequence size of query q with (queries_height, queries_width).
        k_size (`Tuple[int]`):
            Spatial sequence size of key k with (keys_height, keys_width).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    queries_height, queries_width = q_size
    keys_height, keys_width = k_size
    relative_height = get_rel_pos(queries_height, keys_height, rel_pos_h)
    relative_width = get_rel_pos(queries_width, keys_width, rel_pos_w)

    batch_size, _, dim = queries.shape
    r_q = queries.reshape(batch_size, queries_height, queries_width, dim)
    relative_height = torch.einsum("bhwc,hkc->bhwk", r_q, relative_height)
    relative_weight = torch.einsum("bhwc,wkc->bhwk", r_q, relative_width)

    attn = (
        attn.view(batch_size, queries_height, queries_width, keys_height, keys_width)
        + relative_height[:, :, :, :, None]
        + relative_weight[:, :, :, None, :]
    ).view(batch_size, queries_height * queries_width, keys_height * keys_width)

    return attn


class VitDetAttention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, input_size=None):
        """
        Args:
            config (`VitDetConfig`):
                Model configuration.
            input_size (`Tuple[int]`, *optional*):
                Input resolution, only required in case relative position embeddings are added.
        """
        super().__init__()

        dim = config.hidden_size
        num_heads = config.num_attention_heads

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_relative_position_embeddings = config.use_relative_position_embeddings
        if self.use_relative_position_embeddings:
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, hidden_state, output_attentions=False):
        batch_size, height, width, _ = hidden_state.shape
        # qkv with shape (3, batch_size, num_heads, height * width, num_channels)
        qkv = self.qkv(hidden_state).reshape(batch_size, height * width, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # queries, keys and values have shape (batch_size * num_heads, height * width, num_channels)
        queries, keys, values = qkv.reshape(3, batch_size * self.num_heads, height * width, -1).unbind(0)

        attention_scores = (queries * self.scale) @ keys.transpose(-2, -1)

        if self.use_relative_position_embeddings:
            attention_scores = add_decomposed_relative_positions(
                attention_scores, queries, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )

        attention_probs = attention_scores.softmax(dim=-1)

        hidden_state = attention_probs @ values
        hidden_state = hidden_state.view(batch_size, self.num_heads, height, width, -1)
        hidden_state = hidden_state.permute(0, 2, 3, 1, 4)
        hidden_state = hidden_state.reshape(batch_size, height, width, -1)
        hidden_state = self.proj(hidden_state)

        if output_attentions:
            attention_probs = attention_probs.reshape(
                batch_size, self.num_heads, attention_probs.shape[-2], attention_probs.shape[-1]
            )
            outputs = (hidden_state, attention_probs)
        else:
            outputs = (hidden_state,)

        return outputs


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output

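# Illustration of the stochastic depth behaviour above (not used by the model code itself):
# whole samples are zeroed with probability `drop_prob` and the survivors are rescaled by
# 1 / keep_prob, so the expected value of the output matches the input. For example:
#
#     x = torch.ones(4, 16, 16, 8)
#     y = drop_path(x, drop_prob=0.5, training=True)
#     # each of the 4 samples in `y` is now either all zeros or all 2.0 (= 1 / 0.5)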
# Copied from transformers.models.beit.modeling_beit.BeitDropPath
class VitDetDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class VitDetLayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and variance normalization over the
    channel dimension for inputs that have shape (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class VitDetResBottleneckBlock(nn.Module):
    """
    The standard bottleneck residual block without the last activation layer. It contains 3 conv layers with kernels
    1x1, 3x3, 1x1.
    """

    def __init__(self, config, in_channels, out_channels, bottleneck_channels):
        """
        Args:
            config (`VitDetConfig`):
                Model configuration.
            in_channels (`int`):
                Number of input channels.
            out_channels (`int`):
                Number of output channels.
            bottleneck_channels (`int`):
                Number of output channels for the 3x3 "bottleneck" conv layers.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = VitDetLayerNorm(bottleneck_channels)
        self.act1 = ACT2FN[config.hidden_act]

        self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
        self.norm2 = VitDetLayerNorm(bottleneck_channels)
        self.act2 = ACT2FN[config.hidden_act]

        self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = VitDetLayerNorm(out_channels)

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)

        out = x + out
        return out


class VitDetMlp(nn.Module):
    def __init__(self, config, in_features: int, hidden_features: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(config.dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)

        return x


def window_partition(hidden_state, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        hidden_state (`torch.Tensor`):
            Input tokens with [batch_size, height, width, num_channels].
        window_size (`int`):
            Window size.

    Returns:
        `tuple(torch.FloatTensor)` comprising various elements:
        - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
        - (padded_height, padded_width): padded height and width before partition
    """
    batch_size, height, width, num_channels = hidden_state.shape

    pad_height = (window_size - height % window_size) % window_size
    pad_width = (window_size - width % window_size) % window_size

    # Noop in case pad_width == 0 and pad_height == 0.
    hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))

    padded_height, padded_width = height + pad_height, width + pad_width

    hidden_state = hidden_state.view(
        batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
    )
    windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows, (padded_height, padded_width)


def window_unpartition(windows, window_size, pad_height_width, height_width):
    """
    Window unpartition into original sequences and removing padding.

    Args:
        windows (`torch.Tensor`):
            Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
        window_size (`int`):
            Window size.
        pad_height_width (`Tuple[int]`):
            Padded height and width (padded_height, padded_width).
        height_width (`Tuple[int]`):
            Original height and width before padding.

    Returns:
        hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
    """
    padded_height, padded_width = pad_height_width
    height, width = height_width
    batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
    hidden_state = windows.view(
        batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
    )
    hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
    hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)

    # We always have height <= padded_height and width <= padded_width
    hidden_state = hidden_state[:, :height, :width, :].contiguous()
    return hidden_state

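# Illustrative shapes for the two helpers above: a (1, 14, 14, 768) feature map with
# window_size=7 needs no padding, is split into 4 windows of shape (7, 7, 768), and
# window_unpartition restores the original (1, 14, 14, 768) tensor:
#
#     windows, pad_hw = window_partition(hidden_state, 7)   # windows: (4, 7, 7, 768)
#     restored = window_unpartition(windows, 7, pad_hw, (14, 14))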
class VitDetLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(
        self, config: VitDetConfig, drop_path_rate: float = 0, window_size: int = 0, use_residual_block: bool = False
    ) -> None:
        super().__init__()

        dim = config.hidden_size

        image_size = config.image_size
        image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)

        patch_size = config.patch_size
        patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)

        input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = VitDetAttention(
            config, input_size=input_size if window_size == 0 else (window_size, window_size)
        )

        self.drop_path = VitDetDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.mlp = VitDetMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if self.use_residual_block:
            # Use a residual block with bottleneck channel as dim // 2
            self.residual = VitDetResBottleneckBlock(
                config=config,
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        hidden_states = hidden_states.permute(0, 2, 3, 1)

        shortcut = hidden_states

        hidden_states = self.norm1(hidden_states)

        # Window partition
        if self.window_size > 0:
            height, width = hidden_states.shape[1], hidden_states.shape[2]
            hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)

        self_attention_outputs = self.attention(
            hidden_states,
            output_attentions=output_attentions,
        )
        hidden_states = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # Reverse window partition
        if self.window_size > 0:
            hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))

        # first residual connection
        hidden_states = shortcut + self.drop_path(hidden_states)

        hidden_states = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states)))

        hidden_states = hidden_states.permute(0, 3, 1, 2)

        if self.use_residual_block:
            hidden_states = self.residual(hidden_states)

        outputs = (hidden_states,) + outputs

        return outputs


class VitDetEncoder(nn.Module):
    def __init__(self, config: VitDetConfig) -> None:
        super().__init__()
        self.config = config
        depth = config.num_hidden_layers

        # stochastic depth decay rule
        drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")]

        layers = []
        for i in range(depth):
            layers.append(
                VitDetLayer(
                    config,
                    drop_path_rate=drop_path_rate[i],
                    window_size=config.window_size if i in config.window_block_indices else 0,
                    use_residual_block=i in config.residual_block_indices,
                )
            )

        self.layer = nn.ModuleList(layers)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


def caffe2_msra_fill(module: nn.Module) -> None:
    """
    Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. Also initializes `module.bias` to 0.

    Source: https://detectron2.readthedocs.io/en/latest/_modules/fvcore/nn/weight_init.html.

    Args:
        module (torch.nn.Module): module to initialize.
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    if module.bias is not None:
        nn.init.constant_(module.bias, 0)


class VitDetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitDetConfig
    base_model_prefix = "vitdet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        elif isinstance(module, VitDetEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

        elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings:
            module.rel_pos_h.data = nn.init.trunc_normal_(
                module.rel_pos_h.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            )
            module.rel_pos_w.data = nn.init.trunc_normal_(
                module.rel_pos_w.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            )

        elif isinstance(module, VitDetResBottleneckBlock):
            for layer in [module.conv1, module.conv2, module.conv3]:
                caffe2_msra_fill(layer)
            for layer in [module.norm1, module.norm2]:
                layer.weight.data.fill_(1.0)
                layer.bias.data.zero_()
            # zero init last norm layer.
            module.norm3.weight.data.zero_()
            module.norm3.bias.data.zero_()


VITDET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VitDetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITDET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare VitDet Transformer model outputting raw hidden-states without any specific head on top.",
    VITDET_START_DOCSTRING,
)
class VitDetModel(VitDetPreTrainedModel):
    def __init__(self, config: VitDetConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = VitDetEmbeddings(config)
        self.encoder = VitDetEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VitDetEmbeddings:
        return self.embeddings.projection

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import VitDetConfig, VitDetModel
        >>> import torch

        >>> config = VitDetConfig()
        >>> model = VitDetModel(config)

        >>> pixel_values = torch.randn(1, 3, 224, 224)

        >>> with torch.no_grad():
        ...     outputs = model(pixel_values)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 768, 14, 14]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """
    ViTDet backbone, to be used with frameworks like Mask R-CNN.
    """,
    VITDET_START_DOCSTRING,
)
class VitDetBackbone(VitDetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.embeddings = VitDetEmbeddings(config)
        self.encoder = VitDetEncoder(config)
        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]

        # initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> VitDetEmbeddings:
        return self.embeddings.projection

    @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import VitDetConfig, VitDetBackbone
        >>> import torch

        >>> config = VitDetConfig()
        >>> model = VitDetBackbone(config)

        >>> pixel_values = torch.randn(1, 3, 224, 224)

        >>> with torch.no_grad():
        ...     outputs = model(pixel_values)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 14, 14]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_hidden_states=True,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["VitDetModel", "VitDetPreTrainedModel", "VitDetBackbone"]
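The backbone's `feature_maps` follow the `out_features`/`out_indices` selection from the configuration. A minimal sketch with two illustrative stages (the shapes follow from the defaults: a 224x224 input, 16x16 patches and hidden size 768; `stage6`/`stage12` are example choices, not required ones):

```python
>>> import torch
>>> from transformers import VitDetConfig, VitDetBackbone

>>> config = VitDetConfig(out_features=["stage6", "stage12"])
>>> model = VitDetBackbone(config)

>>> with torch.no_grad():
...     outputs = model(torch.randn(1, 3, 224, 224))

>>> feature_maps = outputs.feature_maps
>>> len(feature_maps)
2
>>> list(feature_maps[0].shape)
[1, 768, 14, 14]
```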
docs/transformers/build/lib/transformers/models/vitmatte/__init__.py
ADDED
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vitmatte import *
    from .image_processing_vitmatte import *
    from .modeling_vitmatte import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
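With this lazy module in place, the public ViTMatte classes are imported from the top-level package as usual and the submodules are only loaded on first access. A short usage sketch:

```python
>>> from transformers import VitMatteConfig, VitMatteImageProcessor, VitMatteForImageMatting

>>> config = VitMatteConfig()
>>> model = VitMatteForImageMatting(config)
```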
docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py
ADDED
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VitMatte model configuration"""

import copy
from typing import List

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class VitMatteConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of [`VitMatteForImageMatting`]. It is used to
    instantiate a ViTMatte model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the ViTMatte
    [hustvl/vitmatte-small-composition-1k](https://huggingface.co/hustvl/vitmatte-small-composition-1k) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
            The configuration of the backbone model.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        hidden_size (`int`, *optional*, defaults to 384):
            The number of input channels of the decoder.
        batch_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the batch norm layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        convstream_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192]`):
            The output channels of the ConvStream module.
        fusion_hidden_sizes (`List[int]`, *optional*, defaults to `[256, 128, 64, 32]`):
            The output channels of the Fusion blocks.

    Example:

    ```python
    >>> from transformers import VitMatteConfig, VitMatteForImageMatting

    >>> # Initializing a ViTMatte hustvl/vitmatte-small-composition-1k style configuration
    >>> configuration = VitMatteConfig()

    >>> # Initializing a model (with random weights) from the hustvl/vitmatte-small-composition-1k style configuration
    >>> model = VitMatteForImageMatting(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vitmatte"

    def __init__(
        self,
        backbone_config: PretrainedConfig = None,
        backbone=None,
        use_pretrained_backbone=False,
        use_timm_backbone=False,
        backbone_kwargs=None,
        hidden_size: int = 384,
        batch_norm_eps: float = 1e-5,
        initializer_range: float = 0.02,
        convstream_hidden_sizes: List[int] = [48, 96, 192],
        fusion_hidden_sizes: List[int] = [256, 128, 64, 32],
        **kwargs,
    ):
        super().__init__(**kwargs)

        if backbone_config is None and backbone is None:
            logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.")
            backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"])
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.get("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.backbone_kwargs = backbone_kwargs
        self.batch_norm_eps = batch_norm_eps
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range
        self.convstream_hidden_sizes = convstream_hidden_sizes
        self.fusion_hidden_sizes = fusion_hidden_sizes

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
        `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output["backbone_config"] = self.backbone_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output


__all__ = ["VitMatteConfig"]
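Because the backbone configuration is nested, a custom ViTDet backbone can be passed in directly. The values below are a sketch along the lines of `get_config` in the conversion script that follows (4 input channels for RGB plus trimap, last stage as the feature map); they are illustrative rather than canonical settings:

```python
>>> from transformers import VitDetConfig, VitMatteConfig, VitMatteForImageMatting

>>> backbone_config = VitDetConfig(num_channels=4, image_size=512, out_features=["stage12"])
>>> config = VitMatteConfig(backbone_config=backbone_config, hidden_size=384)
>>> model = VitMatteForImageMatting(config)
```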
docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py
ADDED
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert VitMatte checkpoints from the original repository.

URL: https://github.com/hustvl/ViTMatte
"""

import argparse

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import VitDetConfig, VitMatteConfig, VitMatteForImageMatting, VitMatteImageProcessor


def get_config(model_name):
    hidden_size = 384 if "small" in model_name else 768
    num_attention_heads = 6 if "small" in model_name else 12

    backbone_config = VitDetConfig(
        num_channels=4,
        image_size=512,
        pretrain_image_size=224,
        patch_size=16,
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        use_absolute_position_embeddings=True,
        use_relative_position_embeddings=True,
        window_size=14,
        # 2, 5, 8, 11 for global attention
        window_block_indices=[0, 1, 3, 4, 6, 7, 9, 10],
        residual_block_indices=[2, 5, 8, 11],
        out_features=["stage12"],
    )

    return VitMatteConfig(backbone_config=backbone_config, hidden_size=hidden_size)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config):
    rename_keys = []

    # fmt: off
    # stem
    rename_keys.append(("backbone.pos_embed", "backbone.embeddings.position_embeddings"))
    rename_keys.append(("backbone.patch_embed.proj.weight", "backbone.embeddings.projection.weight"))
    rename_keys.append(("backbone.patch_embed.proj.bias", "backbone.embeddings.projection.bias"))
    # fmt: on

    return rename_keys


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


def convert_vitmatte_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
    config = get_config(model_name)

    # load original state dict
    model_name_to_filename = {
        "vitmatte-small-composition-1k": "ViTMatte_S_Com.pth",
        "vitmatte-base-composition-1k": "ViTMatte_B_Com.pth",
        "vitmatte-small-distinctions-646": "ViTMatte_S_DIS.pth",
        "vitmatte-base-distinctions-646": "ViTMatte_B_DIS.pth",
    }

    filename = model_name_to_filename[model_name]
    filepath = hf_hub_download(repo_id="nielsr/vitmatte-checkpoints", filename=filename, repo_type="model")
    state_dict = torch.load(filepath, map_location="cpu", weights_only=True)

    # rename keys
    for key in state_dict.copy().keys():
        val = state_dict.pop(key)
        if "backbone.blocks" in key:
            key = key.replace("backbone.blocks", "backbone.encoder.layer")
        if "attn" in key:
            key = key.replace("attn", "attention")
        if "fusion_blks" in key:
            key = key.replace("fusion_blks", "fusion_blocks")
        if "bn" in key:
            key = key.replace("bn", "batch_norm")
        state_dict[key] = val

    # rename keys
    rename_keys = create_rename_keys(config)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)

    # create model
    processor = VitMatteImageProcessor()
    model = VitMatteForImageMatting(config)
    model.eval()

    # load state dict
    model.load_state_dict(state_dict)

    # verify on dummy image + trimap
    url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_rgb.png?raw=true"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_trimap.png?raw=true"
    trimap = Image.open(requests.get(url, stream=True).raw)

    pixel_values = processor(images=image, trimaps=trimap.convert("L"), return_tensors="pt").pixel_values

    with torch.no_grad():
        alphas = model(pixel_values).alphas

    if model_name == "vitmatte-small-composition-1k":
        expected_slice = torch.tensor([[0.9977, 0.9987, 0.9990], [0.9980, 0.9998, 0.9998], [0.9983, 0.9998, 0.9998]])
    elif model_name == "vitmatte-base-composition-1k":
        expected_slice = torch.tensor([[0.9972, 0.9971, 0.9981], [0.9948, 0.9987, 0.9994], [0.9963, 0.9992, 0.9995]])
    elif model_name == "vitmatte-small-distinctions-646":
        expected_slice = torch.tensor([[0.9880, 0.9970, 0.9972], [0.9960, 0.9996, 0.9997], [0.9963, 0.9996, 0.9997]])
    elif model_name == "vitmatte-base-distinctions-646":
        expected_slice = torch.tensor([[0.9963, 0.9998, 0.9999], [0.9995, 1.0000, 1.0000], [0.9992, 0.9999, 1.0000]])

    assert torch.allclose(alphas[0, 0, :3, :3], expected_slice, atol=1e-4)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model and processor of {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and processor for {model_name} to hub")
        model.push_to_hub(f"hustvl/{model_name}")
        processor.push_to_hub(f"hustvl/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="vitmatte-small-composition-1k",
        type=str,
        choices=[
            "vitmatte-small-composition-1k",
            "vitmatte-base-composition-1k",
            "vitmatte-small-distinctions-646",
            "vitmatte-base-distinctions-646",
        ],
        help="Name of the VitMatte model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_vitmatte_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
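The script above is normally driven through `argparse` from the command line; the snippet below is only a hedged sketch of calling the converter programmatically, assuming the file is importable under the hypothetical module name `convert_vitmatte_to_hf` and that the original checkpoints are reachable on the Hub.

```python
# Sketch: drive the converter without the CLI (hypothetical import path).
from convert_vitmatte_to_hf import convert_vitmatte_checkpoint

convert_vitmatte_checkpoint(
    model_name="vitmatte-small-composition-1k",
    pytorch_dump_folder_path="./vitmatte-small-composition-1k",  # local output directory
    push_to_hub=False,
)
```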
docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py
ADDED
|
@@ -0,0 +1,272 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for ViTMatte."""

from typing import List, Optional, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import pad, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    get_image_size,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
    validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, logging


logger = logging.get_logger(__name__)


class VitMatteImageProcessor(BaseImageProcessor):
    r"""
    Constructs a ViTMatte image processor.

    Args:
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        do_pad (`bool`, *optional*, defaults to `True`):
            Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden
            by the `do_pad` parameter in the `preprocess` method.
        size_divisibility (`int`, *optional*, defaults to 32):
            The width and height of the image will be padded to be divisible by this number.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = True,
        size_divisibility: int = 32,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.do_rescale = do_rescale
        self.do_normalize = do_normalize
        self.do_pad = do_pad
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.size_divisibility = size_divisibility

    def pad_image(
        self,
        image: np.ndarray,
        size_divisibility: int = 32,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """
        Args:
            image (`np.ndarray`):
                Image to pad.
            size_divisibility (`int`, *optional*, defaults to 32):
                The width and height of the image will be padded to be divisible by this number.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        if input_data_format is None:
            input_data_format = infer_channel_dimension_format(image)

        height, width = get_image_size(image, input_data_format)

        pad_height = 0 if height % size_divisibility == 0 else size_divisibility - height % size_divisibility
        pad_width = 0 if width % size_divisibility == 0 else size_divisibility - width % size_divisibility
        if pad_width + pad_height > 0:
            padding = ((0, pad_height), (0, pad_width))
            image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format)

        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_data_format)

        return image

    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        trimaps: ImageInput,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisibility: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            trimaps (`ImageInput`):
                Trimap to preprocess.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            do_pad (`bool`, *optional*, defaults to `self.do_pad`):
                Whether to pad the image.
            size_divisibility (`int`, *optional*, defaults to `self.size_divisibility`):
                The size divisibility to pad the image to if `do_pad` is set to `True`.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
        """
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        do_pad = do_pad if do_pad is not None else self.do_pad
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        size_divisibility = size_divisibility if size_divisibility is not None else self.size_divisibility

        images = make_list_of_images(images)
        trimaps = make_list_of_images(trimaps, expected_ndims=2)

        if not valid_images(trimaps):
            raise ValueError(
                "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        validate_preprocess_arguments(
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisibility=size_divisibility,
        )

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]
        trimaps = [to_numpy_array(trimap) for trimap in trimaps]

        if do_rescale and is_scaled_image(images[0]):
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        if do_rescale:
            images = [
                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
                for image in images
            ]
            trimaps = [
                self.rescale(image=trimap, scale=rescale_factor, input_data_format=input_data_format)
                for trimap in trimaps
            ]

        if do_normalize:
            images = [
                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
                for image in images
            ]

        # concatenate images and trimaps
        images = [
            np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps)
        ]

        if do_pad:
            images = [
                self.pad_image(image, size_divisibility=size_divisibility, input_data_format=input_data_format)
                for image in images
            ]

        images = [
            to_channel_dimension_format(image=image, channel_dim=data_format, input_channel_dim=input_data_format)
            for image in images
        ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["VitMatteImageProcessor"]
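To make the processor's behaviour concrete: it appends the trimap as a fourth channel and pads height and width up to the next multiple of `size_divisibility`. The snippet below is a minimal sketch with random arrays; the exact output shape shown is what the default settings would produce for a 500x700 input.

```python
# Hedged sketch: 3-channel image + 1-channel trimap -> 4-channel tensor,
# padded from 500x700 up to 512x704 (next multiples of 32).
import numpy as np
from transformers import VitMatteImageProcessor

processor = VitMatteImageProcessor()
image = np.random.randint(0, 256, (500, 700, 3), dtype=np.uint8)   # H x W x 3
trimap = np.random.randint(0, 256, (500, 700), dtype=np.uint8)     # H x W

inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
print(inputs.pixel_values.shape)  # expected: torch.Size([1, 4, 512, 704])
```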
docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py
ADDED
|
@@ -0,0 +1,341 @@
# coding=utf-8
# Copyright 2023 HUST-VL and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViTMatte model."""

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import nn

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...utils.backbone_utils import load_backbone
from .configuration_vitmatte import VitMatteConfig


# General docstring
_CONFIG_FOR_DOC = "VitMatteConfig"


@dataclass
class ImageMattingOutput(ModelOutput):
    """
    Class for outputs of image matting models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Loss.
        alphas (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Estimated alpha values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
            (also called feature maps) of the model at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    alphas: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class VitMattePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitMatteConfig
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = []

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class VitMatteBasicConv3x3(nn.Module):
    """
    Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
    """

    def __init__(self, config, in_channels, out_channels, stride=2, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
        self.relu = nn.ReLU()

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        hidden_state = self.batch_norm(hidden_state)
        hidden_state = self.relu(hidden_state)

        return hidden_state


class VitMatteConvStream(nn.Module):
    """
    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
    """

    def __init__(self, config):
        super().__init__()

        # We use a default in-case there isn't a backbone config set. This is for backwards compatibility and
        # to enable loading HF backbone models.
        in_channels = 4
        if config.backbone_config is not None:
            in_channels = config.backbone_config.num_channels

        out_channels = config.convstream_hidden_sizes

        self.convs = nn.ModuleList()
        self.conv_chans = [in_channels] + out_channels

        for i in range(len(self.conv_chans) - 1):
            in_chan_ = self.conv_chans[i]
            out_chan_ = self.conv_chans[i + 1]
            self.convs.append(VitMatteBasicConv3x3(config, in_chan_, out_chan_))

    def forward(self, pixel_values):
        out_dict = {"detailed_feature_map_0": pixel_values}
        embeddings = pixel_values
        for i in range(len(self.convs)):
            embeddings = self.convs[i](embeddings)
            name_ = "detailed_feature_map_" + str(i + 1)
            out_dict[name_] = embeddings

        return out_dict


class VitMatteFusionBlock(nn.Module):
    """
    Simple fusion block to fuse features from ConvStream and Plain Vision Transformer.
    """

    def __init__(self, config, in_channels, out_channels):
        super().__init__()
        self.conv = VitMatteBasicConv3x3(config, in_channels, out_channels, stride=1, padding=1)

    def forward(self, features, detailed_feature_map):
        upscaled_features = nn.functional.interpolate(features, scale_factor=2, mode="bilinear", align_corners=False)
        out = torch.cat([detailed_feature_map, upscaled_features], dim=1)
        out = self.conv(out)

        return out


class VitMatteHead(nn.Module):
    """
    Simple Matting Head, containing only conv3x3 and conv1x1 layers.
    """

    def __init__(self, config):
        super().__init__()

        in_channels = config.fusion_hidden_sizes[-1]
        mid_channels = 16

        self.matting_convs = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(True),
            nn.Conv2d(mid_channels, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, hidden_state):
        hidden_state = self.matting_convs(hidden_state)

        return hidden_state


class VitMatteDetailCaptureModule(nn.Module):
    """
    Simple and lightweight Detail Capture Module for ViT Matting.
    """

    def __init__(self, config):
        super().__init__()
        if len(config.fusion_hidden_sizes) != len(config.convstream_hidden_sizes) + 1:
            raise ValueError(
                "The length of fusion_hidden_sizes should be equal to the length of convstream_hidden_sizes + 1."
            )

        self.config = config
        self.convstream = VitMatteConvStream(config)
        self.conv_chans = self.convstream.conv_chans

        self.fusion_blocks = nn.ModuleList()
        self.fusion_channels = [config.hidden_size] + config.fusion_hidden_sizes

        for i in range(len(self.fusion_channels) - 1):
            self.fusion_blocks.append(
                VitMatteFusionBlock(
                    config=config,
                    in_channels=self.fusion_channels[i] + self.conv_chans[-(i + 1)],
                    out_channels=self.fusion_channels[i + 1],
                )
            )

        self.matting_head = VitMatteHead(config)

    def forward(self, features, pixel_values):
        detail_features = self.convstream(pixel_values)
        for i in range(len(self.fusion_blocks)):
            detailed_feature_map_name = "detailed_feature_map_" + str(len(self.fusion_blocks) - i - 1)
            features = self.fusion_blocks[i](features, detail_features[detailed_feature_map_name])

        alphas = torch.sigmoid(self.matting_head(features))

        return alphas


VITMATTE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VitMatteConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITMATTE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`VitMatteImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See
            `attentions` under returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under
            returned tensors for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    """ViTMatte framework leveraging any vision backbone e.g. for ADE20k, CityScapes.""",
    VITMATTE_START_DOCSTRING,
)
class VitMatteForImageMatting(VitMattePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.backbone = load_backbone(config)
        self.decoder = VitMatteDetailCaptureModule(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VITMATTE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=ImageMattingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth image matting for computing the loss.

        Returns:

        Examples:

        ```python
        >>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting
        >>> import torch
        >>> from PIL import Image
        >>> from huggingface_hub import hf_hub_download

        >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
        >>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")

        >>> filepath = hf_hub_download(
        ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
        ... )
        >>> image = Image.open(filepath).convert("RGB")
        >>> filepath = hf_hub_download(
        ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
        ... )
        >>> trimap = Image.open(filepath).convert("L")

        >>> # prepare image + trimap for the model
        >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")

        >>> with torch.no_grad():
        ...     alphas = model(**inputs).alphas
        >>> print(alphas.shape)
        torch.Size([1, 1, 640, 960])
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not yet supported")

        outputs = self.backbone.forward_with_filtered_kwargs(
            pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
        )

        features = outputs.feature_maps[-1]
        alphas = self.decoder(features, pixel_values)

        if not return_dict:
            output = (alphas,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageMattingOutput(
            loss=loss,
            alphas=alphas,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VitMattePreTrainedModel", "VitMatteForImageMatting"]
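Since the decoder runs on padded inputs, the returned `alphas` have the padded spatial size rather than the original image size. A hedged post-processing sketch (the file paths are placeholders, and the bottom/right-only padding is the behaviour documented in the image processor above):

```python
# Hedged sketch: crop the predicted alpha matte back to the original image size.
import torch
from PIL import Image
from transformers import VitMatteForImageMatting, VitMatteImageProcessor

processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")

image = Image.open("image.png").convert("RGB")    # placeholder path
trimap = Image.open("trimap.png").convert("L")    # placeholder path

inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
with torch.no_grad():
    alphas = model(**inputs).alphas               # (1, 1, padded_height, padded_width)

# Padding is added only on the bottom and right, so the top-left crop recovers the matte.
alpha_matte = alphas[0, 0, : image.height, : image.width]
```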
docs/transformers/build/lib/transformers/models/vitpose/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vitpose import *
    from .image_processing_vitpose import *
    from .modeling_vitpose import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py
ADDED
|
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VitPose model configuration"""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING


logger = logging.get_logger(__name__)


class VitPoseConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VitPoseForPoseEstimation`]. It is used to instantiate a
    VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VitPose
    [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
            The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
        backbone (`str`, *optional*):
            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
            Whether to use pretrained weights for the backbone.
        use_timm_backbone (`bool`, *optional*, defaults to `False`):
            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
            library.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
            e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_factor (`int`, *optional*, defaults to 4):
            Factor to upscale the feature maps coming from the ViT backbone.
        use_simple_decoder (`bool`, *optional*, defaults to `True`):
            Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`.


    Example:

    ```python
    >>> from transformers import VitPoseConfig, VitPoseForPoseEstimation

    >>> # Initializing a VitPose configuration
    >>> configuration = VitPoseConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = VitPoseForPoseEstimation(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vitpose"

    def __init__(
        self,
        backbone_config: Optional[PretrainedConfig] = None,
        backbone: Optional[str] = None,
        use_pretrained_backbone: bool = False,
        use_timm_backbone: bool = False,
        backbone_kwargs: Optional[dict] = None,
        initializer_range: float = 0.02,
        scale_factor: int = 4,
        use_simple_decoder: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if use_pretrained_backbone:
            logger.info(
                "`use_pretrained_backbone` is `True`. For the pure inference purpose of VitPose weight do not set this value."
            )
        if use_timm_backbone:
            raise ValueError("use_timm_backbone set `True` is not supported at the moment.")

        if backbone_config is None and backbone is None:
            logger.info("`backbone_config` is `None`. Initializing the config with the default `VitPose` backbone.")
            backbone_config = CONFIG_MAPPING["vitpose_backbone"](out_indices=[4])
        elif isinstance(backbone_config, dict):
            backbone_model_type = backbone_config.get("model_type")
            config_class = CONFIG_MAPPING[backbone_model_type]
            backbone_config = config_class.from_dict(backbone_config)

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.backbone_config = backbone_config
        self.backbone = backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = use_timm_backbone
        self.backbone_kwargs = backbone_kwargs

        self.initializer_range = initializer_range
        self.scale_factor = scale_factor
        self.use_simple_decoder = use_simple_decoder


__all__ = ["VitPoseConfig"]
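Analogous to the ViTMatte configuration above, a minimal hedged sketch of pairing `VitPoseConfig` with an explicit `VitPoseBackboneConfig`; the values shown follow the documented defaults and the base-size settings used by the conversion script later in this diff, not required values.

```python
# Hedged sketch: custom backbone config feeding VitPoseConfig (classic decoder variant).
from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation

backbone_config = VitPoseBackboneConfig(out_indices=[12])  # last layer of a 12-layer ViT-B
config = VitPoseConfig(
    backbone_config=backbone_config,
    scale_factor=4,            # upscaling of backbone feature maps
    use_simple_decoder=False,  # use VitPoseClassicDecoder instead of VitPoseSimpleDecoder
)
model = VitPoseForPoseEstimation(config)  # randomly initialized weights
```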
docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py
ADDED
|
@@ -0,0 +1,428 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert VitPose checkpoints from the original repository.

URL: https://github.com/vitae-transformer/vitpose

Notebook to get the original logits: https://colab.research.google.com/drive/1QDX_2POTpl6JaZAV2WIFjuiqDsDwiqMZ?usp=sharing.
"""

import argparse
import os
import re

import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor


ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"patch_embed.proj": "embeddings.patch_embeddings.projection",
    r"pos_embed": "embeddings.position_embeddings",
    r"blocks": "encoder.layer",
    r"attn.proj": "attention.output.dense",
    r"attn": "attention.self",
    r"norm1": "layernorm_before",
    r"norm2": "layernorm_after",
    r"last_norm": "layernorm",
    r"keypoint_head": "head",
    r"final_layer": "conv",
}

MODEL_TO_FILE_NAME_MAPPING = {
    # VitPose models, simple decoder
    "vitpose-base-simple": "vitpose-b-simple.pth",
    # VitPose models, classic decoder
    "vitpose-base": "vitpose-b.pth",
    # VitPose models, COCO-AIC-MPII
    "vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth",
    # VitPose+ models
    "vitpose-plus-small": "vitpose+_small.pth",
    "vitpose-plus-base": "vitpose+_base.pth",
    "vitpose-plus-large": "vitpose+_large.pth",
    "vitpose-plus-huge": "vitpose+_huge.pth",
}


def get_config(model_name):
    if "plus" in model_name:
        num_experts = 6
        if "small" in model_name:
            part_features = 96
            out_indices = [12]
        elif "base" in model_name:
            part_features = 192
            out_indices = [12]
        elif "large" in model_name:
            part_features = 256
            out_indices = [24]
        elif "huge" in model_name:
            part_features = 320
            out_indices = [32]
        else:
            raise ValueError(f"Model {model_name} not supported")
    else:
        num_experts = 1
        part_features = 0

    # size of the architecture
    if "small" in model_name:
        hidden_size = 384
        num_hidden_layers = 12
        num_attention_heads = 12
    elif "large" in model_name:
        hidden_size = 1024
        num_hidden_layers = 24
        num_attention_heads = 16
    elif "huge" in model_name:
        hidden_size = 1280
        num_hidden_layers = 32
        num_attention_heads = 16

    backbone_config = VitPoseBackboneConfig(
        out_indices=out_indices,
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        num_experts=num_experts,
        part_features=part_features,
    )

    use_simple_decoder = "simple" in model_name

    edges = [
        [15, 13],
        [13, 11],
        [16, 14],
        [14, 12],
        [11, 12],
        [5, 11],
        [6, 12],
        [5, 6],
        [5, 7],
        [6, 8],
        [7, 9],
        [8, 10],
        [1, 2],
        [0, 1],
        [0, 2],
        [1, 3],
        [2, 4],
        [3, 5],
        [4, 6],
    ]
    id2label = {
        0: "Nose",
        1: "L_Eye",
        2: "R_Eye",
        3: "L_Ear",
        4: "R_Ear",
        5: "L_Shoulder",
        6: "R_Shoulder",
        7: "L_Elbow",
        8: "R_Elbow",
        9: "L_Wrist",
        10: "R_Wrist",
        11: "L_Hip",
        12: "R_Hip",
        13: "L_Knee",
        14: "R_Knee",
        15: "L_Ankle",
        16: "R_Ankle",
    }

    label2id = {v: k for k, v in id2label.items()}

    config = VitPoseConfig(
        backbone_config=backbone_config,
        num_labels=17,
        use_simple_decoder=use_simple_decoder,
        edges=edges,
        id2label=id2label,
        label2id=label2id,
    )

    return config


def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
    """
    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.
    """
    output_dict = {}
    if state_dict_keys is not None:
        old_text = "\n".join(state_dict_keys)
        new_text = old_text
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
            if replacement is None:
                new_text = re.sub(pattern, "", new_text)  # an empty line
                continue
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict


# We will verify our results on a COCO image
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000000139.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


@torch.no_grad()
def write_model(model_name, model_path, push_to_hub, check_logits=True):
    # ------------------------------------------------------------
    # Vision model params and config
    # ------------------------------------------------------------

    # params from config
    config = get_config(model_name)

    # ------------------------------------------------------------
    # Convert weights
    # ------------------------------------------------------------

    # load original state_dict
    filename = MODEL_TO_FILE_NAME_MAPPING[model_name]
    print(f"Fetching all parameters from the checkpoint at {filename}...")

    checkpoint_path = hf_hub_download(
        repo_id="nielsr/vitpose-original-checkpoints", filename=filename, repo_type="model"
    )

    print("Converting model...")
    original_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
    all_keys = list(original_state_dict.keys())
    new_keys = convert_old_keys_to_new_keys(all_keys)

    dim = config.backbone_config.hidden_size

    state_dict = {}
    for key in all_keys:
        new_key = new_keys[key]
        value = original_state_dict[key]

        if re.search("associate_heads", new_key) or re.search("backbone.cls_token", new_key):
            # associate_heads belongs to the auxiliary head, which is not needed at inference time.
            # backbone.cls_token is only used by an optional forward path for dynamically changing the input size,
            # see https://github.com/ViTAE-Transformer/ViTPose/issues/34 for details.
            pass
        elif re.search("qkv", new_key):
            state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
            state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
            state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
        elif re.search("head", new_key) and not config.use_simple_decoder:
            # Pattern for deconvolution layers
            deconv_pattern = r"deconv_layers\.(0|3)\.weight"
            new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1)) // 3 + 1}.weight", new_key)
            # Pattern for batch normalization layers
            bn_patterns = [
                (r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
                (r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
                (r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
                (r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
                (r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
            ]

            for pattern, replacement in bn_patterns:
                if re.search(pattern, new_key):
                    # Convert the layer number to the correct batch norm index
                    layer_num = int(re.search(pattern, key).group(1))
                    bn_num = layer_num // 3 + 1
                    new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
            state_dict[new_key] = value
        else:
            state_dict[new_key] = value

    print("Loading the checkpoint in a VitPose model.")
    model = VitPoseForPoseEstimation(config)
    model.eval()
    model.load_state_dict(state_dict)
    print("Checkpoint loaded successfully.")

    # create image processor
    image_processor = VitPoseImageProcessor()

    # verify image processor
    image = prepare_img()
    boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
    pixel_values = image_processor(images=image, boxes=boxes, return_tensors="pt").pixel_values

    filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
    original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)["img"]
    # we allow for a small difference in the pixel values due to the original repository using cv2
    assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)

    dataset_index = torch.tensor([0])

    with torch.no_grad():
        print("Shape of original_pixel_values: ", original_pixel_values.shape)
        print("First values of original_pixel_values: ", original_pixel_values[0, 0, :3, :3])

        # first forward pass
        outputs = model(original_pixel_values, dataset_index=dataset_index)
        output_heatmap = outputs.heatmaps

        print("Shape of output_heatmap: ", output_heatmap.shape)
        print("First values: ", output_heatmap[0, 0, :3, :3])

        # second forward pass (flipped)
        # this is done since the model uses `flip_test=True` in its test config
        original_pixel_values_flipped = torch.flip(original_pixel_values, [3])
        outputs_flipped = model(
            original_pixel_values_flipped,
            dataset_index=dataset_index,
            flip_pairs=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]),
        )
        output_flipped_heatmap = outputs_flipped.heatmaps

    outputs.heatmaps = (output_heatmap + output_flipped_heatmap) * 0.5

    # Verify pose_results
    pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]

    if check_logits:
        # Simple decoder checkpoints
        if model_name == "vitpose-base-simple":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([3.98180511e02, 1.81808380e02]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor([8.66642594e-01]),
                atol=5e-2,
            )
        # Classic decoder checkpoints
        elif model_name == "vitpose-base":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([3.9807913e02, 1.8182812e02]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor([8.8235235e-01]),
                atol=5e-2,
            )
        # COCO-AIC-MPII checkpoints
        elif model_name == "vitpose-base-coco-aic-mpii":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([3.98305542e02, 1.81741592e02]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor([8.69966745e-01]),
                atol=5e-2,
            )
        # VitPose+ models
        elif model_name == "vitpose-plus-small":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([398.1597, 181.6902]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor(0.9051),
                atol=5e-2,
            )
        elif model_name == "vitpose-plus-base":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([3.98201294e02, 1.81728302e02]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor([8.75046968e-01]),
                atol=5e-2,
            )
        elif model_name == "vitpose-plus-large":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([398.1409, 181.7412]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor(0.8746),
                atol=5e-2,
            )
        elif model_name == "vitpose-plus-huge":
            assert torch.allclose(
                pose_results[1]["keypoints"][0],
                torch.tensor([398.2079, 181.8026]),
                atol=5e-2,
            )
            assert torch.allclose(
                pose_results[1]["scores"][0],
                torch.tensor(0.8693),
                atol=5e-2,
            )
        else:
            raise ValueError("Model not supported")
    print("Conversion successfully done.")

    if model_path is not None:
        os.makedirs(model_path, exist_ok=True)
        model.save_pretrained(model_path)
        image_processor.save_pretrained(model_path)

    if push_to_hub:
        print(f"Pushing model and image processor for {model_name} to hub")
        # we created a community organization on the hub for this model
        # maintained by the Transformers team
        model.push_to_hub(f"usyd-community/{model_name}")
        image_processor.push_to_hub(f"usyd-community/{model_name}")


def main():
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="vitpose-base-simple",
        choices=MODEL_TO_FILE_NAME_MAPPING.keys(),
        type=str,
        help="Name of the VitPose model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to store the converted model."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )
    parser.add_argument(
        "--check_logits", action="store_false", help="Whether or not to verify the logits of the converted model."
    )

    args = parser.parse_args()
    write_model(
        model_path=args.pytorch_dump_folder_path,
        model_name=args.model_name,
        push_to_hub=args.push_to_hub,
        check_logits=args.check_logits,
    )


if __name__ == "__main__":
    main()
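For reference, the conversion entry point above can also be driven directly from Python. This is a minimal sketch, not part of the file itself; the import path assumes the script is saved locally as convert_vitpose_to_hf.py, and note that --check_logits uses action="store_false", so logit verification is on by default and passing the flag disables it.

# Hypothetical local invocation of the converter shown above (argument names taken from write_model).
from convert_vitpose_to_hf import write_model  # assumed local module name

write_model(
    model_name="vitpose-base-simple",   # any key of MODEL_TO_FILE_NAME_MAPPING
    model_path="./vitpose-base-simple", # where to save the converted weights, or None
    push_to_hub=False,                  # set True to push to the usyd-community organization
    check_logits=True,                  # compare against the reference logits listed above
)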
docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py
ADDED
|
@@ -0,0 +1,684 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for VitPose."""

import itertools
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_transforms import to_channel_dimension_format
from ...image_utils import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    ChannelDimension,
    ImageInput,
    infer_channel_dimension_format,
    is_scaled_image,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from ...utils import TensorType, is_scipy_available, is_torch_available, is_vision_available, logging


if is_torch_available():
    import torch

if is_vision_available():
    import PIL

if is_scipy_available():
    from scipy.linalg import inv
    from scipy.ndimage import affine_transform, gaussian_filter

if TYPE_CHECKING:
    from .modeling_vitpose import VitPoseEstimatorOutput

logger = logging.get_logger(__name__)


# inspired by https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py#L132
def box_to_center_and_scale(
    box: Union[Tuple, List, np.ndarray],
    image_width: int,
    image_height: int,
    normalize_factor: float = 200.0,
    padding_factor: float = 1.25,
):
    """
    Encodes a bounding box in COCO format into (center, scale).

    Args:
        box (`Tuple`, `List`, or `np.ndarray`):
            Bounding box in COCO format (top_left_x, top_left_y, width, height).
        image_width (`int`):
            Image width.
        image_height (`int`):
            Image height.
        normalize_factor (`float`):
            Width and height scale factor.
        padding_factor (`float`):
            Bounding box padding factor.

    Returns:
        tuple: A tuple containing center and scale.

        - `np.ndarray` [float32](2,): Center of the bbox (x, y).
        - `np.ndarray` [float32](2,): Scale of the bbox width & height.
    """

    top_left_x, top_left_y, width, height = box[:4]
    aspect_ratio = image_width / image_height
    center = np.array([top_left_x + width * 0.5, top_left_y + height * 0.5], dtype=np.float32)

    if width > aspect_ratio * height:
        height = width * 1.0 / aspect_ratio
    elif width < aspect_ratio * height:
        width = height * aspect_ratio

    scale = np.array([width / normalize_factor, height / normalize_factor], dtype=np.float32)
    scale = scale * padding_factor

    return center, scale


def coco_to_pascal_voc(bboxes: np.ndarray) -> np.ndarray:
    """
    Converts bounding boxes from the COCO format to the Pascal VOC format.

    In other words, converts from (top_left_x, top_left_y, width, height) format
    to (top_left_x, top_left_y, bottom_right_x, bottom_right_y).

    Args:
        bboxes (`np.ndarray` of shape `(batch_size, 4)`):
            Bounding boxes in COCO format.

    Returns:
        `np.ndarray` of shape `(batch_size, 4)` in Pascal VOC format.
    """
    bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0] - 1
    bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1] - 1

    return bboxes


def get_keypoint_predictions(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get keypoint predictions from score maps.

    Args:
        heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
            Model predicted heatmaps.

    Returns:
        tuple: A tuple containing aggregated results.

        - coords (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
            Predicted keypoint locations.
        - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
            Scores (confidence) of the keypoints.
    """
    if not isinstance(heatmaps, np.ndarray):
        raise ValueError("Heatmaps should be np.ndarray")
    if heatmaps.ndim != 4:
        raise ValueError("Heatmaps should be 4-dimensional")

    batch_size, num_keypoints, _, width = heatmaps.shape
    heatmaps_reshaped = heatmaps.reshape((batch_size, num_keypoints, -1))
    idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
    scores = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = preds[:, :, 0] % width
    preds[:, :, 1] = preds[:, :, 1] // width

    preds = np.where(np.tile(scores, (1, 1, 2)) > 0.0, preds, -1)
    return preds, scores


def post_dark_unbiased_data_processing(coords: np.ndarray, batch_heatmaps: np.ndarray, kernel: int = 3) -> np.ndarray:
    """DARK post-processing. Implemented by unbiased_data_processing.

    Paper references:
    - Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
    - Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020).

    Args:
        coords (`np.ndarray` of shape `(num_persons, num_keypoints, 2)`):
            Initial coordinates of human pose.
        batch_heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
            Batched heatmaps as predicted by the model.
            A batch_size of 1 is used for the bottom-up paradigm, where all persons share the same heatmap.
            A batch_size of `num_persons` is used for the top-down paradigm, where each person has its own heatmaps.
        kernel (`int`, *optional*, defaults to 3):
            Gaussian kernel size (K) for modulation.

    Returns:
        `np.ndarray` of shape `(num_persons, num_keypoints, 2)`:
            Refined coordinates.
    """
    batch_size, num_keypoints, height, width = batch_heatmaps.shape
    num_coords = coords.shape[0]
    if not (batch_size == 1 or batch_size == num_coords):
        raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.")
    radius = int((kernel - 1) // 2)
    batch_heatmaps = np.array(
        [
            [gaussian_filter(heatmap, sigma=0.8, radius=(radius, radius), axes=(0, 1)) for heatmap in heatmaps]
            for heatmaps in batch_heatmaps
        ]
    )
    batch_heatmaps = np.clip(batch_heatmaps, 0.001, 50)
    batch_heatmaps = np.log(batch_heatmaps)

    batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten()

    # calculate indices for coordinates
    index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (width + 2)
    index += (width + 2) * (height + 2) * np.arange(0, batch_size * num_keypoints).reshape(-1, num_keypoints)
    index = index.astype(int).reshape(-1, 1)
    i_ = batch_heatmaps_pad[index]
    ix1 = batch_heatmaps_pad[index + 1]
    iy1 = batch_heatmaps_pad[index + width + 2]
    ix1y1 = batch_heatmaps_pad[index + width + 3]
    ix1_y1_ = batch_heatmaps_pad[index - width - 3]
    ix1_ = batch_heatmaps_pad[index - 1]
    iy1_ = batch_heatmaps_pad[index - 2 - width]

    # calculate refined coordinates using Newton's method
    dx = 0.5 * (ix1 - ix1_)
    dy = 0.5 * (iy1 - iy1_)
    derivative = np.concatenate([dx, dy], axis=1)
    derivative = derivative.reshape(num_coords, num_keypoints, 2, 1)
    dxx = ix1 - 2 * i_ + ix1_
    dyy = iy1 - 2 * i_ + iy1_
    dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
    hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
    hessian = hessian.reshape(num_coords, num_keypoints, 2, 2)
    hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
    coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze()
    return coords


def transform_preds(coords: np.ndarray, center: np.ndarray, scale: np.ndarray, output_size: np.ndarray) -> np.ndarray:
    """Get final keypoint predictions from heatmaps and apply scaling and
    translation to map them back to the image.

    Note:
        num_keypoints: K

    Args:
        coords (`np.ndarray` of shape `(num_keypoints, ndims)`):

            * If ndims=2, coords are the predicted keypoint locations.
            * If ndims=4, coords are composed of (x, y, scores, tags).
            * If ndims=5, coords are composed of (x, y, scores, tags,
              flipped_tags).

        center (`np.ndarray` of shape `(2,)`):
            Center of the bounding box (x, y).
        scale (`np.ndarray` of shape `(2,)`):
            Scale of the bounding box wrt the original image width and height.
        output_size (`np.ndarray` of shape `(2,)`):
            Size of the destination heatmaps in (height, width) format.

    Returns:
        np.ndarray: Predicted coordinates in the images.
    """
    if coords.shape[1] not in (2, 4, 5):
        raise ValueError("Coordinates need to have either 2, 4 or 5 dimensions.")
    if len(center) != 2:
        raise ValueError("Center needs to have 2 elements, one for x and one for y.")
    if len(scale) != 2:
        raise ValueError("Scale needs to consist of a width and height")
    if len(output_size) != 2:
        raise ValueError("Output size needs to consist of a height and width")

    # Recover the scale which is normalized by a factor of 200.
    scale = scale * 200.0

    # We use unbiased data processing
    scale_y = scale[1] / (output_size[0] - 1.0)
    scale_x = scale[0] / (output_size[1] - 1.0)

    target_coords = np.ones_like(coords)
    target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
    target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5

    return target_coords


def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray, size_target: np.ndarray):
    """
    Calculate the transformation matrix under the constraint of unbiased. Paper ref: Huang et al. The Devil is in the
    Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Source: https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Args:
        theta (`float`):
            Rotation angle in degrees.
        size_input (`np.ndarray`):
            Size of input image [width, height].
        size_dst (`np.ndarray`):
            Size of output image [width, height].
        size_target (`np.ndarray`):
            Size of ROI in input plane [w, h].

    Returns:
        `np.ndarray`: A matrix for transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = math.cos(theta) * scale_x
    matrix[0, 1] = -math.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * math.sin(theta) + 0.5 * size_target[0]
    )
    matrix[1, 0] = math.sin(theta) * scale_y
    matrix[1, 1] = math.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * math.cos(theta) + 0.5 * size_target[1]
    )
    return matrix


def scipy_warp_affine(src, M, size):
    """
    This function implements the cv2.warpAffine function using affine_transform in scipy. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html and https://docs.opencv.org/4.x/d4/d61/tutorial_warp_affine.html for more details.

    Note: the original implementation of cv2.warpAffine uses cv2.INTER_LINEAR.
    """
    channels = [src[..., i] for i in range(src.shape[-1])]

    # Convert to a 3x3 matrix used by SciPy
    M_scipy = np.vstack([M, [0, 0, 1]])
    # If you have a matrix for the "push" transformation, use its inverse (numpy.linalg.inv) in this function.
    M_inv = inv(M_scipy)
    M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
        M_inv[1, 1],
        M_inv[1, 0],
        M_inv[0, 1],
        M_inv[0, 0],
        M_inv[1, 2],
        M_inv[0, 2],
    )

    new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
    new_src = np.stack(new_src, axis=-1)
    return new_src


class VitPoseImageProcessor(BaseImageProcessor):
    r"""
    Constructs a VitPose image processor.

    Args:
        do_affine_transform (`bool`, *optional*, defaults to `True`):
            Whether to apply an affine transformation to the input images.
        size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 192}`):
            Resolution of the image after `affine_transform` is applied. Only has an effect if `do_affine_transform` is set to `True`. Can
            be overridden by `size` in the `preprocess` method.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
            method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether or not to normalize the input with mean and standard deviation.
        image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`, *optional*):
            The sequence of means for each channel, to be used when normalizing images.
        image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`, *optional*):
            The sequence of standard deviations for each channel, to be used when normalizing images.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_affine_transform: bool = True,
        size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.do_affine_transform = do_affine_transform
        self.size = size if size is not None else {"height": 256, "width": 192}
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
        self.normalize_factor = 200.0

    def affine_transform(
        self,
        image: np.array,
        center: Tuple[float],
        scale: Tuple[float],
        rotation: float,
        size: Dict[str, int],
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.array:
        """
        Apply an affine transformation to an image.

        Args:
            image (`np.array`):
                Image to transform.
            center (`Tuple[float]`):
                Center of the bounding box (x, y).
            scale (`Tuple[float]`):
                Scale of the bounding box with respect to height/width.
            rotation (`float`):
                Rotation angle in degrees.
            size (`Dict[str, int]`):
                Size of the destination image.
            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format of the output image.
            input_data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the input image.
        """

        data_format = input_data_format if data_format is None else data_format

        size = (size["width"], size["height"])

        # one uses a pixel standard deviation of 200 pixels
        transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0)

        # input image requires channels last format
        image = (
            image
            if input_data_format == ChannelDimension.LAST
            else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
        )
        image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))

        image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST)

        return image

    def preprocess(
        self,
        images: ImageInput,
        boxes: Union[List[List[float]], np.ndarray],
        do_affine_transform: Optional[bool] = None,
        size: Dict[str, int] = None,
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> PIL.Image.Image:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.

            boxes (`List[List[List[float]]]` or `np.ndarray`):
                List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
                box coordinates in COCO format (top_left_x, top_left_y, width, height).

            do_affine_transform (`bool`, *optional*, defaults to `self.do_affine_transform`):
                Whether to apply an affine transformation to the input images.
            size (`Dict[str, int]` *optional*, defaults to `self.size`):
                Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
                resizing.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use if `do_normalize` is set to `True`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
              width).
        """
        do_affine_transform = do_affine_transform if do_affine_transform is not None else self.do_affine_transform
        size = size if size is not None else self.size
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if isinstance(boxes, list) and len(images) != len(boxes):
            raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {len(boxes)}")
        elif isinstance(boxes, np.ndarray) and len(images) != boxes.shape[0]:
            raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {boxes.shape[0]}")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if is_scaled_image(images[0]) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )

        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(images[0])

        # transformations (affine transformation + rescaling + normalization)
        if self.do_affine_transform:
            new_images = []
            for image, image_boxes in zip(images, boxes):
                for box in image_boxes:
                    center, scale = box_to_center_and_scale(
                        box,
                        image_width=size["width"],
                        image_height=size["height"],
                        normalize_factor=self.normalize_factor,
                    )
                    transformed_image = self.affine_transform(
                        image, center, scale, rotation=0, size=size, input_data_format=input_data_format
                    )
                    new_images.append(transformed_image)
            images = new_images

        # For batch processing, the number of boxes must be consistent across all images in the batch.
        # When using a list input, the number of boxes can vary dynamically per image.
        # The image processor creates pixel_values of shape (batch_size*num_persons, num_channels, height, width)

        all_images = []
        for image in images:
            if do_rescale:
                image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

            if do_normalize:
                image = self.normalize(
                    image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
                )

            all_images.append(image)
        images = [
            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
            for image in all_images
        ]

        data = {"pixel_values": images}
        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        return encoded_inputs

    def keypoints_from_heatmaps(
        self,
        heatmaps: np.ndarray,
        center: np.ndarray,
        scale: np.ndarray,
        kernel: int = 11,
    ):
        """
        Get final keypoint predictions from heatmaps and transform them back to
        the image.

        Args:
            heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
                Model predicted heatmaps.
            center (`np.ndarray` of shape `(batch_size, 2)`):
                Center of the bounding box (x, y).
            scale (`np.ndarray` of shape `(batch_size, 2)`):
                Scale of the bounding box wrt the original images' width and height.
            kernel (int, *optional*, defaults to 11):
                Gaussian kernel size (K) for modulation, which should match the heatmap gaussian sigma when training.
                K=17 for sigma=3 and K=11 for sigma=2.

        Returns:
            tuple: A tuple containing keypoint predictions and scores.

            - preds (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
                Predicted keypoint locations in images.
            - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
                Scores (confidence) of the keypoints.
        """
        batch_size, _, height, width = heatmaps.shape

        coords, scores = get_keypoint_predictions(heatmaps)

        preds = post_dark_unbiased_data_processing(coords, heatmaps, kernel=kernel)

        # Transform back to the image
        for i in range(batch_size):
            preds[i] = transform_preds(preds[i], center=center[i], scale=scale[i], output_size=[height, width])

        return preds, scores

    def post_process_pose_estimation(
        self,
        outputs: "VitPoseEstimatorOutput",
        boxes: Union[List[List[List[float]]], np.ndarray],
        kernel_size: int = 11,
        threshold: Optional[float] = None,
        target_sizes: Union[TensorType, List[Tuple]] = None,
    ):
        """
        Transform the heatmaps into keypoint predictions and transform them back to the image.

        Args:
            outputs (`VitPoseEstimatorOutput`):
                VitPoseForPoseEstimation model outputs.
            boxes (`List[List[List[float]]]` or `np.ndarray`):
                List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
                box coordinates in COCO format (top_left_x, top_left_y, width, height).
            kernel_size (`int`, *optional*, defaults to 11):
                Gaussian kernel size (K) for modulation.
            threshold (`float`, *optional*, defaults to None):
                Score threshold to keep object detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will be resized with the default value.
        Returns:
            `List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an image
            in the batch as predicted by the model.
        """

        # First compute centers and scales for each bounding box
        batch_size, num_keypoints, _, _ = outputs.heatmaps.shape

        if target_sizes is not None:
            if batch_size != len(target_sizes):
                raise ValueError(
                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
                )

        centers = np.zeros((batch_size, 2), dtype=np.float32)
        scales = np.zeros((batch_size, 2), dtype=np.float32)
        flattened_boxes = list(itertools.chain(*boxes))
        for i in range(batch_size):
            if target_sizes is not None:
                image_width, image_height = target_sizes[i][0], target_sizes[i][1]
                scale_factor = np.array([image_width, image_height, image_width, image_height])
                flattened_boxes[i] = flattened_boxes[i] * scale_factor
            width, height = self.size["width"], self.size["height"]
            center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
            centers[i, :] = center
            scales[i, :] = scale

        preds, scores = self.keypoints_from_heatmaps(
            outputs.heatmaps.cpu().numpy(), centers, scales, kernel=kernel_size
        )

        all_boxes = np.zeros((batch_size, 4), dtype=np.float32)
        all_boxes[:, 0:2] = centers[:, 0:2]
        all_boxes[:, 2:4] = scales[:, 0:2]

        poses = torch.tensor(preds)
        scores = torch.tensor(scores)
        labels = torch.arange(0, num_keypoints)
        bboxes_xyxy = torch.tensor(coco_to_pascal_voc(all_boxes))

        results: List[List[Dict[str, torch.Tensor]]] = []

        pose_bbox_pairs = zip(poses, scores, bboxes_xyxy)

        for image_bboxes in boxes:
            image_results: List[Dict[str, torch.Tensor]] = []
            for _ in image_bboxes:
                # Unpack the next pose and bbox_xyxy from the iterator
                pose, score, bbox_xyxy = next(pose_bbox_pairs)
                score = score.squeeze()
                keypoints_labels = labels
                if threshold is not None:
                    keep = score > threshold
                    pose = pose[keep]
                    score = score[keep]
                    keypoints_labels = keypoints_labels[keep]
                pose_result = {"keypoints": pose, "scores": score, "labels": keypoints_labels, "bbox": bbox_xyxy}
                image_results.append(pose_result)
            results.append(image_results)

        return results


__all__ = ["VitPoseImageProcessor"]
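To make the preprocessing and post-processing contract above concrete, here is a minimal top-down inference sketch. It is not part of the file; the checkpoint id is an assumption that mirrors the `usyd-community/{model_name}` repositories targeted by the conversion script, and the image and box reuse the COCO example used there.

# Minimal usage sketch for VitPoseImageProcessor together with VitPoseForPoseEstimation.
import requests
import torch
from PIL import Image
from transformers import VitPoseForPoseEstimation, VitPoseImageProcessor

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000000139.jpg", stream=True).raw
)
# One list of COCO-format boxes (x, y, w, h) per image, one box per detected person.
boxes = [[[412.8, 157.61, 53.05, 138.01]]]

processor = VitPoseImageProcessor()
# Assumed checkpoint name; the conversion script pushes converted weights to usyd-community/{model_name}.
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

# preprocess() crops each box with the unbiased affine transform and stacks the crops into pixel_values.
inputs = processor(images=image, boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(inputs.pixel_values, dataset_index=torch.tensor([0]))

# post_process_pose_estimation() maps the heatmaps back to image coordinates:
# one dict per box with "keypoints", "scores", "labels" and "bbox" (Pascal VOC format).
pose_results = processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
print(pose_results[0]["keypoints"].shape)  # (num_keypoints, 2)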
docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py
ADDED
|
@@ -0,0 +1,340 @@
# coding=utf-8
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose model."""

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import load_backbone
from .configuration_vitpose import VitPoseConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "VitPoseConfig"


@dataclass
class VitPoseEstimatorOutput(ModelOutput):
    """
    Class for outputs of pose estimation models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
        heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
            Heatmaps as predicted by the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
            (also called feature maps) of the model at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    heatmaps: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class VitPosePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitPoseConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


VITPOSE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VitPoseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITPOSE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`VitPoseImageProcessor`]. See
            [`VitPoseImageProcessor.__call__`] for details.

        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.

            This corresponds to the dataset index used during training, e.g. For the single dataset index 0 refers to the corresponding dataset. For the multiple datasets index 0 refers to dataset A (e.g. MPII) and index 1 refers to dataset B (e.g. CrowdPose).

        flip_pairs (`torch.tensor`, *optional*):
            Whether to mirror pairs of keypoints (for example, left ear -- right ear).

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
    """Flip the flipped heatmaps back to the original form.

    Args:
        output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
            The output heatmaps obtained from the flipped images.
        flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
            Pairs of keypoints which are mirrored (for example, left ear -- right ear).
        target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
            Target type to use. Can be gaussian-heatmap or combined-target.
            gaussian-heatmap: Classification target with gaussian distribution.
|
| 143 |
+
combined-target: The combination of classification target (response map) and regression target (offset map).
|
| 144 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
| 145 |
+
|
| 146 |
+
Returns:
|
| 147 |
+
torch.Tensor: heatmaps that flipped back to the original image
|
| 148 |
+
"""
|
| 149 |
+
if target_type not in ["gaussian-heatmap", "combined-target"]:
|
| 150 |
+
raise ValueError("target_type should be gaussian-heatmap or combined-target")
|
| 151 |
+
|
| 152 |
+
if output_flipped.ndim != 4:
|
| 153 |
+
raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
|
| 154 |
+
batch_size, num_keypoints, height, width = output_flipped.shape
|
| 155 |
+
channels = 1
|
| 156 |
+
if target_type == "combined-target":
|
| 157 |
+
channels = 3
|
| 158 |
+
output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
|
| 159 |
+
output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
|
| 160 |
+
output_flipped_back = output_flipped.clone()
|
| 161 |
+
|
| 162 |
+
# Swap left-right parts
|
| 163 |
+
for left, right in flip_pairs.tolist():
|
| 164 |
+
output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
|
| 165 |
+
output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
|
| 166 |
+
output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
|
| 167 |
+
# Flip horizontally
|
| 168 |
+
output_flipped_back = output_flipped_back.flip(-1)
|
| 169 |
+
return output_flipped_back
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class VitPoseSimpleDecoder(nn.Module):
|
| 173 |
+
"""
|
| 174 |
+
Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
|
| 175 |
+
feature maps into heatmaps.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def __init__(self, config) -> None:
|
| 179 |
+
super().__init__()
|
| 180 |
+
|
| 181 |
+
self.activation = nn.ReLU()
|
| 182 |
+
self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
|
| 183 |
+
self.conv = nn.Conv2d(
|
| 184 |
+
config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 188 |
+
# Transform input: ReLU + upsample
|
| 189 |
+
hidden_state = self.activation(hidden_state)
|
| 190 |
+
hidden_state = self.upsampling(hidden_state)
|
| 191 |
+
heatmaps = self.conv(hidden_state)
|
| 192 |
+
|
| 193 |
+
if flip_pairs is not None:
|
| 194 |
+
heatmaps = flip_back(heatmaps, flip_pairs)
|
| 195 |
+
|
| 196 |
+
return heatmaps
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class VitPoseClassicDecoder(nn.Module):
|
| 200 |
+
"""
|
| 201 |
+
Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer,
|
| 202 |
+
turning the feature maps into heatmaps.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def __init__(self, config: VitPoseConfig):
|
| 206 |
+
super().__init__()
|
| 207 |
+
|
| 208 |
+
self.deconv1 = nn.ConvTranspose2d(
|
| 209 |
+
config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
|
| 210 |
+
)
|
| 211 |
+
self.batchnorm1 = nn.BatchNorm2d(256)
|
| 212 |
+
self.relu1 = nn.ReLU()
|
| 213 |
+
|
| 214 |
+
self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
|
| 215 |
+
self.batchnorm2 = nn.BatchNorm2d(256)
|
| 216 |
+
self.relu2 = nn.ReLU()
|
| 217 |
+
|
| 218 |
+
self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
|
| 219 |
+
|
| 220 |
+
def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
|
| 221 |
+
hidden_state = self.deconv1(hidden_state)
|
| 222 |
+
hidden_state = self.batchnorm1(hidden_state)
|
| 223 |
+
hidden_state = self.relu1(hidden_state)
|
| 224 |
+
|
| 225 |
+
hidden_state = self.deconv2(hidden_state)
|
| 226 |
+
hidden_state = self.batchnorm2(hidden_state)
|
| 227 |
+
hidden_state = self.relu2(hidden_state)
|
| 228 |
+
|
| 229 |
+
heatmaps = self.conv(hidden_state)
|
| 230 |
+
|
| 231 |
+
if flip_pairs is not None:
|
| 232 |
+
heatmaps = flip_back(heatmaps, flip_pairs)
|
| 233 |
+
|
| 234 |
+
return heatmaps
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@add_start_docstrings(
|
| 238 |
+
"The VitPose model with a pose estimation head on top.",
|
| 239 |
+
VITPOSE_START_DOCSTRING,
|
| 240 |
+
)
|
| 241 |
+
class VitPoseForPoseEstimation(VitPosePreTrainedModel):
|
| 242 |
+
def __init__(self, config: VitPoseConfig) -> None:
|
| 243 |
+
super().__init__(config)
|
| 244 |
+
|
| 245 |
+
self.backbone = load_backbone(config)
|
| 246 |
+
|
| 247 |
+
# add backbone attributes
|
| 248 |
+
if not hasattr(self.backbone.config, "hidden_size"):
|
| 249 |
+
raise ValueError("The backbone should have a hidden_size attribute")
|
| 250 |
+
if not hasattr(self.backbone.config, "image_size"):
|
| 251 |
+
raise ValueError("The backbone should have an image_size attribute")
|
| 252 |
+
if not hasattr(self.backbone.config, "patch_size"):
|
| 253 |
+
raise ValueError("The backbone should have a patch_size attribute")
|
| 254 |
+
|
| 255 |
+
self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
|
| 256 |
+
|
| 257 |
+
# Initialize weights and apply final processing
|
| 258 |
+
self.post_init()
|
| 259 |
+
|
| 260 |
+
@add_start_docstrings_to_model_forward(VITPOSE_INPUTS_DOCSTRING)
|
| 261 |
+
@replace_return_docstrings(output_type=VitPoseEstimatorOutput, config_class=_CONFIG_FOR_DOC)
|
| 262 |
+
def forward(
|
| 263 |
+
self,
|
| 264 |
+
pixel_values: torch.Tensor,
|
| 265 |
+
dataset_index: Optional[torch.Tensor] = None,
|
| 266 |
+
flip_pairs: Optional[torch.Tensor] = None,
|
| 267 |
+
labels: Optional[torch.Tensor] = None,
|
| 268 |
+
output_attentions: Optional[bool] = None,
|
| 269 |
+
output_hidden_states: Optional[bool] = None,
|
| 270 |
+
return_dict: Optional[bool] = None,
|
| 271 |
+
) -> Union[tuple, VitPoseEstimatorOutput]:
|
| 272 |
+
"""
|
| 273 |
+
Returns:
|
| 274 |
+
|
| 275 |
+
Examples:
|
| 276 |
+
|
| 277 |
+
```python
|
| 278 |
+
>>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
|
| 279 |
+
>>> import torch
|
| 280 |
+
>>> from PIL import Image
|
| 281 |
+
>>> import requests
|
| 282 |
+
|
| 283 |
+
>>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
|
| 284 |
+
>>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
|
| 285 |
+
|
| 286 |
+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
| 287 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
| 288 |
+
>>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
|
| 289 |
+
>>> inputs = processor(image, boxes=boxes, return_tensors="pt")
|
| 290 |
+
|
| 291 |
+
>>> with torch.no_grad():
|
| 292 |
+
... outputs = model(**inputs)
|
| 293 |
+
>>> heatmaps = outputs.heatmaps
|
| 294 |
+
```"""
|
| 295 |
+
|
| 296 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 297 |
+
output_hidden_states = (
|
| 298 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 299 |
+
)
|
| 300 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 301 |
+
|
| 302 |
+
loss = None
|
| 303 |
+
if labels is not None:
|
| 304 |
+
raise NotImplementedError("Training is not yet supported")
|
| 305 |
+
|
| 306 |
+
outputs = self.backbone.forward_with_filtered_kwargs(
|
| 307 |
+
pixel_values,
|
| 308 |
+
dataset_index=dataset_index,
|
| 309 |
+
output_hidden_states=output_hidden_states,
|
| 310 |
+
output_attentions=output_attentions,
|
| 311 |
+
return_dict=return_dict,
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Turn output hidden states in tensor of shape (batch_size, num_channels, height, width)
|
| 315 |
+
sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1]
|
| 316 |
+
batch_size = sequence_output.shape[0]
|
| 317 |
+
patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
|
| 318 |
+
patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
|
| 319 |
+
sequence_output = (
|
| 320 |
+
sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous()
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
|
| 324 |
+
|
| 325 |
+
if not return_dict:
|
| 326 |
+
if output_hidden_states:
|
| 327 |
+
output = (heatmaps,) + outputs[1:]
|
| 328 |
+
else:
|
| 329 |
+
output = (heatmaps,) + outputs[2:]
|
| 330 |
+
return ((loss,) + output) if loss is not None else output
|
| 331 |
+
|
| 332 |
+
return VitPoseEstimatorOutput(
|
| 333 |
+
loss=loss,
|
| 334 |
+
heatmaps=heatmaps,
|
| 335 |
+
hidden_states=outputs.hidden_states,
|
| 336 |
+
attentions=outputs.attentions,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]
|
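# --- Illustrative sketch, not part of modeling_vitpose.py above ---
# The model returns heatmaps of shape (batch_size, num_keypoints, height, width).
# A minimal way to read keypoint locations out of them is an argmax per heatmap,
# as sketched below. This toy decoder is an assumption for illustration only; it
# skips the sub-pixel refinement and box rescaling a real post-processing step
# (e.g. in the image processor) would perform.
import torch


def naive_decode_heatmaps(heatmaps: torch.Tensor):
    """Return (coords, scores); coords is (batch, num_keypoints, 2) as (x, y) on the heatmap grid."""
    batch_size, num_keypoints, height, width = heatmaps.shape
    flat = heatmaps.reshape(batch_size, num_keypoints, -1)
    scores, flat_idx = flat.max(dim=-1)
    ys = torch.div(flat_idx, width, rounding_mode="floor")
    xs = flat_idx % width
    coords = torch.stack([xs, ys], dim=-1)
    return coords, scores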
docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py
ADDED
@@ -0,0 +1,17 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vitpose_backbone import *
    from .modeling_vitpose_backbone import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
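# --- Illustrative sketch, separate from the __init__.py above ---
# The file above registers a `_LazyModule` so that heavy submodules are only
# imported on first attribute access. The toy module class below reproduces
# that general idea in isolation; it is a hypothetical illustration and not
# how `_LazyModule` is implemented internally.
import importlib
import types


class ToyLazyModule(types.ModuleType):
    def __init__(self, name: str, symbol_to_submodule: dict):
        super().__init__(name)
        # e.g. {"VitPoseBackboneConfig": "configuration_vitpose_backbone"}
        self._symbol_to_submodule = symbol_to_submodule

    def __getattr__(self, attr: str):
        submodule = self._symbol_to_submodule.get(attr)
        if submodule is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        # Import the submodule only now, when the symbol is actually requested.
        module = importlib.import_module(f"{self.__name__}.{submodule}")
        return getattr(module, attr)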
docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
ADDED
@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VitPose backbone configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


logger = logging.get_logger(__name__)


class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
    VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the VitPose
    [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to `[256, 192]`):
            The size (resolution) of each image.
        patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        mlp_ratio (`int`, *optional*, defaults to 4):
            The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
        num_experts (`int`, *optional*, defaults to 1):
            The number of experts in the MoE layer.
        part_features (`int`, *optional*):
            The number of part features to output. Only used in case `num_experts` is greater than 1.
        hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage. Must be in the
            same order as defined in the `stage_names` attribute.

    Example:

    ```python
    >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone

    >>> # Initializing a VitPose configuration
    >>> configuration = VitPoseBackboneConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = VitPoseBackbone(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vitpose_backbone"

    def __init__(
        self,
        image_size=[256, 192],
        patch_size=[16, 16],
        num_channels=3,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        mlp_ratio=4,
        num_experts=1,
        part_features=256,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        qkv_bias=True,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.mlp_ratio = mlp_ratio
        self.num_experts = num_experts
        self.part_features = part_features
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )


__all__ = ["VitPoseBackboneConfig"]
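# --- Illustrative sketch, separate from configuration_vitpose_backbone.py above ---
# With the defaults above, the config derives a "stem" plus 12 transformer
# stages and, per its docstring, falls back to the last stage when neither
# `out_features` nor `out_indices` is given. Printing the derived attributes
# makes that concrete (the expected values in the comments are assumptions
# based on the docstring, not verified output).
from transformers import VitPoseBackboneConfig

config = VitPoseBackboneConfig()
print(config.stage_names)   # ['stem', 'stage1', ..., 'stage12']
print(config.out_features)  # expected to default to the last stage, e.g. ['stage12']
print(config.patch_size)    # [16, 16] -> a 256x192 input gives a 16x12 patch grid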
docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py
ADDED
@@ -0,0 +1,579 @@
# coding=utf-8
# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VitPose backbone model.

This code is the same as the original Vision Transformer (ViT) with 2 modifications:
- use of padding=2 in the patch embedding layer
- addition of a mixture-of-experts MLP layer
"""

import collections.abc
from typing import Callable, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput, BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_vitpose_backbone import VitPoseBackboneConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "VitPoseBackboneConfig"


class VitPoseBackbonePatchEmbeddings(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, config):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size
        num_channels = config.num_channels
        embed_dim = config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        height, width = pixel_values.shape[-2:]
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values)

        embeddings = embeddings.flatten(2).transpose(1, 2)
        return embeddings


class VitPoseBackboneEmbeddings(nn.Module):
    """
    Construct the position and patch embeddings.
    """

    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()

        self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        embeddings = self.patch_embeddings(pixel_values)

        # add positional encoding to each token
        embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]

        embeddings = self.dropout(embeddings)

        return embeddings


# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
class VitPoseBackboneSelfAttention(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
class VitPoseBackboneSelfOutput(nn.Module):
    """
    The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
class VitPoseBackboneAttention(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        self.attention = VitPoseBackboneSelfAttention(config)
        self.output = VitPoseBackboneSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class VitPoseBackboneMoeMLP(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig):
        super().__init__()

        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)

        num_experts = config.num_experts
        part_features = config.part_features

        self.part_features = part_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features - part_features)
        self.drop = nn.Dropout(config.hidden_dropout_prob)

        self.num_experts = num_experts
        experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
        self.experts = nn.ModuleList(experts)

    def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])

        hidden_state = self.fc1(hidden_state)
        hidden_state = self.act(hidden_state)
        shared_hidden_state = self.fc2(hidden_state)
        indices = indices.view(-1, 1, 1)

        # to support ddp training
        for i in range(self.num_experts):
            selected_index = indices == i
            current_hidden_state = self.experts[i](hidden_state) * selected_index
            expert_hidden_state = expert_hidden_state + current_hidden_state

        hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)

        return hidden_state


class VitPoseBackboneMLP(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        in_features = out_features = config.hidden_size
        hidden_features = int(config.hidden_size * config.mlp_ratio)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features, bias=True)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class VitPoseBackboneLayer(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        self.num_experts = config.num_experts
        self.attention = VitPoseBackboneAttention(config)
        self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        dataset_index: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Validate dataset_index when using multiple experts
        if self.num_experts > 1 and dataset_index is None:
            raise ValueError(
                "dataset_index must be provided when using multiple experts "
                f"(num_experts={self.num_experts}). Please provide dataset_index "
                "to the forward pass."
            )
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in VitPoseBackbone, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        layer_output = self.layernorm_after(hidden_states)
        if self.num_experts == 1:
            layer_output = self.mlp(layer_output)
        else:
            layer_output = self.mlp(layer_output, indices=dataset_index)

        # second residual connection
        layer_output = layer_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
class VitPoseBackboneEncoder(nn.Module):
    def __init__(self, config: VitPoseBackboneConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    # Ignore copy
    def forward(
        self,
        hidden_states: torch.Tensor,
        dataset_index: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    dataset_index,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class VitPoseBackbonePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitPoseBackboneConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, VitPoseBackboneEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)


VITPOSE_BACKBONE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values.

        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.

            This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The VitPose backbone useful for downstream tasks.",
    VITPOSE_BACKBONE_START_DOCSTRING,
)
class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
    def __init__(self, config: VitPoseBackboneConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
        self.embeddings = VitPoseBackboneEmbeddings(config)
        self.encoder = VitPoseBackboneEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        dataset_index: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
        >>> import torch

        >>> config = VitPoseBackboneConfig(out_indices=[-1])
        >>> model = VitPoseBackbone(config)

        >>> pixel_values = torch.randn(1, 3, 256, 192)
        >>> dataset_index = torch.tensor([1])
        >>> outputs = model(pixel_values, dataset_index)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            dataset_index=dataset_index,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                hidden_state = self.layernorm(hidden_state)
                feature_maps += (hidden_state,)

        if not return_dict:
            if output_hidden_states:
                output = (feature_maps,) + outputs[1:]
            else:
                output = (feature_maps,) + outputs[2:]
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["VitPoseBackbonePreTrainedModel", "VitPoseBackbone"]
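# --- Illustrative sketch, separate from modeling_vitpose_backbone.py above ---
# VitPoseBackboneMoeMLP routes the "part" features by multiplying every
# expert's output with a 0/1 mask derived from `dataset_index` and summing the
# results, so all experts stay in the computation graph (the "to support ddp
# training" comment above). The standalone toy below reproduces that routing;
# all sizes are made up for the demo.
import torch
from torch import nn

batch_size, seq_len, hidden_features, part_features, num_experts = 4, 6, 32, 8, 3
experts = nn.ModuleList(nn.Linear(hidden_features, part_features) for _ in range(num_experts))

hidden = torch.randn(batch_size, seq_len, hidden_features)
dataset_index = torch.tensor([0, 2, 1, 2])   # one expert id per sample
mask_index = dataset_index.view(-1, 1, 1)    # broadcast over tokens and features

expert_out = torch.zeros(batch_size, seq_len, part_features)
for i, expert in enumerate(experts):
    # Every expert runs on every sample, but only the matching samples keep its output.
    expert_out = expert_out + expert(hidden) * (mask_index == i)

# Sample 1 was routed to expert 2, so its output equals that expert applied directly.
assert torch.allclose(expert_out[1], experts[2](hidden[1]))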
docs/transformers/build/lib/transformers/models/vits/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vits import *
    from .modeling_vits import *
    from .tokenization_vits import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vits/configuration_vits.py
ADDED
@@ -0,0 +1,253 @@
# coding=utf-8
# Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VITS model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class VitsConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VitsModel`]. It is used to instantiate a VITS
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the VITS
    [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 38):
            Vocabulary size of the VITS model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed to the forward method of [`VitsModel`].
        hidden_size (`int`, *optional*, defaults to 192):
            Dimensionality of the text encoder layers.
        num_hidden_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 2):
            Number of attention heads for each attention layer in the Transformer encoder.
        window_size (`int`, *optional*, defaults to 4):
            Window size for the relative positional embeddings in the attention layers of the Transformer encoder.
        use_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the key, query, value projection layers in the Transformer encoder.
        ffn_dim (`int`, *optional*, defaults to 768):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        layerdrop (`float`, *optional*, defaults to 0.1):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        ffn_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size of the 1D convolution layers used by the feed-forward network in the Transformer encoder.
        flow_size (`int`, *optional*, defaults to 192):
            Dimensionality of the flow layers.
        spectrogram_bins (`int`, *optional*, defaults to 513):
            Number of frequency bins in the target spectrogram.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings and encoder.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for activations inside the fully connected layer.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        use_stochastic_duration_prediction (`bool`, *optional*, defaults to `True`):
            Whether to use the stochastic duration prediction module or the regular duration predictor.
        num_speakers (`int`, *optional*, defaults to 1):
            Number of speakers if this is a multi-speaker model.
        speaker_embedding_size (`int`, *optional*, defaults to 0):
            Number of channels used by the speaker embeddings. Is zero for single-speaker models.
        upsample_initial_channel (`int`, *optional*, defaults to 512):
            The number of input channels into the HiFi-GAN upsampling network.
        upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
            A tuple of integers defining the stride of each 1D convolutional layer in the HiFi-GAN upsampling network.
            The length of `upsample_rates` defines the number of convolutional layers and has to match the length of
            `upsample_kernel_sizes`.
        upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
            A tuple of integers defining the kernel size of each 1D convolutional layer in the HiFi-GAN upsampling
            network. The length of `upsample_kernel_sizes` defines the number of convolutional layers and has to match
            the length of `upsample_rates`.
        resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
            A tuple of integers defining the kernel sizes of the 1D convolutional layers in the HiFi-GAN
            multi-receptive field fusion (MRF) module.
        resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
            HiFi-GAN multi-receptive field fusion (MRF) module.
        leaky_relu_slope (`float`, *optional*, defaults to 0.1):
            The angle of the negative slope used by the leaky ReLU activation.
        depth_separable_channels (`int`, *optional*, defaults to 2):
            Number of channels to use in each depth-separable block.
        depth_separable_num_layers (`int`, *optional*, defaults to 3):
            Number of convolutional layers to use in each depth-separable block.
        duration_predictor_flow_bins (`int`, *optional*, defaults to 10):
            Number of channels to map using the unconstrained rational spline in the duration predictor model.
        duration_predictor_tail_bound (`float`, *optional*, defaults to 5.0):
            Value of the tail bin boundary when computing the unconstrained rational spline in the duration predictor
            model.
        duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
            Kernel size of the 1D convolution layers used in the duration predictor model.
        duration_predictor_dropout (`float`, *optional*, defaults to 0.5):
            The dropout ratio for the duration predictor model.
        duration_predictor_num_flows (`int`, *optional*, defaults to 4):
            Number of flow stages used by the duration predictor model.
        duration_predictor_filter_channels (`int`, *optional*, defaults to 256):
            Number of channels for the convolution layers used in the duration predictor model.
        prior_encoder_num_flows (`int`, *optional*, defaults to 4):
            Number of flow stages used by the prior encoder flow model.
        prior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 4):
            Number of WaveNet layers used by the prior encoder flow model.
        posterior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 16):
            Number of WaveNet layers used by the posterior encoder model.
        wavenet_kernel_size (`int`, *optional*, defaults to 5):
            Kernel size of the 1D convolution layers used in the WaveNet model.
        wavenet_dilation_rate (`int`, *optional*, defaults to 1):
            Dilation rates of the dilated 1D convolutional layers used in the WaveNet model.
        wavenet_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the WaveNet layers.
        speaking_rate (`float`, *optional*, defaults to 1.0):
            Speaking rate. Larger values give faster synthesised speech.
        noise_scale (`float`, *optional*, defaults to 0.667):
            How random the speech prediction is. Larger values create more variation in the predicted speech.
        noise_scale_duration (`float`, *optional*, defaults to 0.8):
            How random the duration prediction is. Larger values create more variation in the predicted durations.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate at which the output audio waveform is digitalized, expressed in hertz (Hz).

    Example:

    ```python
    >>> from transformers import VitsModel, VitsConfig

    >>> # Initializing a "facebook/mms-tts-eng" style configuration
    >>> configuration = VitsConfig()

    >>> # Initializing a model (with random weights) from the "facebook/mms-tts-eng" style configuration
    >>> model = VitsModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_act="relu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)


__all__ = ["VitsConfig"]
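The constructor above checks that `upsample_kernel_sizes` and `upsample_rates` have the same length, so any customization of the HiFi-GAN upsampler has to change both lists together. A short illustrative sketch (the values below are made up for illustration, not taken from a released checkpoint):

```python
from transformers import VitsConfig, VitsModel

# Both lists must have the same length, otherwise VitsConfig.__init__ raises a ValueError.
config = VitsConfig(
    sampling_rate=22_050,
    upsample_rates=[8, 8, 4, 2],
    upsample_kernel_sizes=[16, 16, 8, 4],
)
model = VitsModel(config)  # randomly initialized model with this architecture
```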
docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py
ADDED
@@ -0,0 +1,390 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert VITS checkpoint."""

import argparse
import json
import tempfile

import torch
from huggingface_hub import hf_hub_download

from transformers import VitsConfig, VitsModel, VitsTokenizer, logging


logging.set_verbosity_info()
logger = logging.get_logger("transformers.models.vits")

MAPPING_TEXT_ENCODER = {
    "enc_p.emb": "text_encoder.embed_tokens",
    "enc_p.encoder.attn_layers.*.conv_k": "text_encoder.encoder.layers.*.attention.k_proj",
    "enc_p.encoder.attn_layers.*.conv_v": "text_encoder.encoder.layers.*.attention.v_proj",
    "enc_p.encoder.attn_layers.*.conv_q": "text_encoder.encoder.layers.*.attention.q_proj",
    "enc_p.encoder.attn_layers.*.conv_o": "text_encoder.encoder.layers.*.attention.out_proj",
    "enc_p.encoder.attn_layers.*.emb_rel_k": "text_encoder.encoder.layers.*.attention.emb_rel_k",
    "enc_p.encoder.attn_layers.*.emb_rel_v": "text_encoder.encoder.layers.*.attention.emb_rel_v",
    "enc_p.encoder.norm_layers_1.*.gamma": "text_encoder.encoder.layers.*.layer_norm.weight",
    "enc_p.encoder.norm_layers_1.*.beta": "text_encoder.encoder.layers.*.layer_norm.bias",
    "enc_p.encoder.ffn_layers.*.conv_1": "text_encoder.encoder.layers.*.feed_forward.conv_1",
    "enc_p.encoder.ffn_layers.*.conv_2": "text_encoder.encoder.layers.*.feed_forward.conv_2",
    "enc_p.encoder.norm_layers_2.*.gamma": "text_encoder.encoder.layers.*.final_layer_norm.weight",
    "enc_p.encoder.norm_layers_2.*.beta": "text_encoder.encoder.layers.*.final_layer_norm.bias",
    "enc_p.proj": "text_encoder.project",
}
MAPPING_STOCHASTIC_DURATION_PREDICTOR = {
    "dp.pre": "duration_predictor.conv_pre",
    "dp.proj": "duration_predictor.conv_proj",
    "dp.convs.convs_sep.*": "duration_predictor.conv_dds.convs_dilated.*",
    "dp.convs.convs_1x1.*": "duration_predictor.conv_dds.convs_pointwise.*",
    "dp.convs.norms_1.*.gamma": "duration_predictor.conv_dds.norms_1.*.weight",
    "dp.convs.norms_1.*.beta": "duration_predictor.conv_dds.norms_1.*.bias",
    "dp.convs.norms_2.*.gamma": "duration_predictor.conv_dds.norms_2.*.weight",
    "dp.convs.norms_2.*.beta": "duration_predictor.conv_dds.norms_2.*.bias",
    "dp.flows.0.logs": "duration_predictor.flows.0.log_scale",
    "dp.flows.0.m": "duration_predictor.flows.0.translate",
    "dp.flows.*.pre": "duration_predictor.flows.*.conv_pre",
    "dp.flows.*.proj": "duration_predictor.flows.*.conv_proj",
    "dp.flows.*.convs.convs_1x1.0": "duration_predictor.flows.*.conv_dds.convs_pointwise.0",
    "dp.flows.*.convs.convs_1x1.1": "duration_predictor.flows.*.conv_dds.convs_pointwise.1",
    "dp.flows.*.convs.convs_1x1.2": "duration_predictor.flows.*.conv_dds.convs_pointwise.2",
    "dp.flows.*.convs.convs_sep.0": "duration_predictor.flows.*.conv_dds.convs_dilated.0",
    "dp.flows.*.convs.convs_sep.1": "duration_predictor.flows.*.conv_dds.convs_dilated.1",
    "dp.flows.*.convs.convs_sep.2": "duration_predictor.flows.*.conv_dds.convs_dilated.2",
    "dp.flows.*.convs.norms_1.0.gamma": "duration_predictor.flows.*.conv_dds.norms_1.0.weight",
    "dp.flows.*.convs.norms_1.0.beta": "duration_predictor.flows.*.conv_dds.norms_1.0.bias",
    "dp.flows.*.convs.norms_1.1.gamma": "duration_predictor.flows.*.conv_dds.norms_1.1.weight",
    "dp.flows.*.convs.norms_1.1.beta": "duration_predictor.flows.*.conv_dds.norms_1.1.bias",
    "dp.flows.*.convs.norms_1.2.gamma": "duration_predictor.flows.*.conv_dds.norms_1.2.weight",
    "dp.flows.*.convs.norms_1.2.beta": "duration_predictor.flows.*.conv_dds.norms_1.2.bias",
    "dp.flows.*.convs.norms_2.0.gamma": "duration_predictor.flows.*.conv_dds.norms_2.0.weight",
    "dp.flows.*.convs.norms_2.0.beta": "duration_predictor.flows.*.conv_dds.norms_2.0.bias",
    "dp.flows.*.convs.norms_2.1.gamma": "duration_predictor.flows.*.conv_dds.norms_2.1.weight",
    "dp.flows.*.convs.norms_2.1.beta": "duration_predictor.flows.*.conv_dds.norms_2.1.bias",
    "dp.flows.*.convs.norms_2.2.gamma": "duration_predictor.flows.*.conv_dds.norms_2.2.weight",
    "dp.flows.*.convs.norms_2.2.beta": "duration_predictor.flows.*.conv_dds.norms_2.2.bias",
    "dp.post_pre": "duration_predictor.post_conv_pre",
    "dp.post_proj": "duration_predictor.post_conv_proj",
    "dp.post_convs.convs_sep.*": "duration_predictor.post_conv_dds.convs_dilated.*",
    "dp.post_convs.convs_1x1.*": "duration_predictor.post_conv_dds.convs_pointwise.*",
    "dp.post_convs.norms_1.*.gamma": "duration_predictor.post_conv_dds.norms_1.*.weight",
    "dp.post_convs.norms_1.*.beta": "duration_predictor.post_conv_dds.norms_1.*.bias",
    "dp.post_convs.norms_2.*.gamma": "duration_predictor.post_conv_dds.norms_2.*.weight",
    "dp.post_convs.norms_2.*.beta": "duration_predictor.post_conv_dds.norms_2.*.bias",
    "dp.post_flows.0.logs": "duration_predictor.post_flows.0.log_scale",
    "dp.post_flows.0.m": "duration_predictor.post_flows.0.translate",
    "dp.post_flows.*.pre": "duration_predictor.post_flows.*.conv_pre",
    "dp.post_flows.*.proj": "duration_predictor.post_flows.*.conv_proj",
    "dp.post_flows.*.convs.convs_1x1.0": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.0",
    "dp.post_flows.*.convs.convs_1x1.1": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.1",
    "dp.post_flows.*.convs.convs_1x1.2": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.2",
    "dp.post_flows.*.convs.convs_sep.0": "duration_predictor.post_flows.*.conv_dds.convs_dilated.0",
    "dp.post_flows.*.convs.convs_sep.1": "duration_predictor.post_flows.*.conv_dds.convs_dilated.1",
    "dp.post_flows.*.convs.convs_sep.2": "duration_predictor.post_flows.*.conv_dds.convs_dilated.2",
    "dp.post_flows.*.convs.norms_1.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.0.weight",
    "dp.post_flows.*.convs.norms_1.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.0.bias",
    "dp.post_flows.*.convs.norms_1.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.1.weight",
    "dp.post_flows.*.convs.norms_1.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.1.bias",
    "dp.post_flows.*.convs.norms_1.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.2.weight",
    "dp.post_flows.*.convs.norms_1.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.2.bias",
    "dp.post_flows.*.convs.norms_2.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.0.weight",
    "dp.post_flows.*.convs.norms_2.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.0.bias",
    "dp.post_flows.*.convs.norms_2.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.1.weight",
    "dp.post_flows.*.convs.norms_2.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.1.bias",
    "dp.post_flows.*.convs.norms_2.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.2.weight",
    "dp.post_flows.*.convs.norms_2.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.2.bias",
    "dp.cond": "duration_predictor.cond",  # num_speakers > 1
}
MAPPING_FLOW = {
    "flow.flows.*.pre": "flow.flows.*.conv_pre",
    "flow.flows.*.enc.in_layers.0": "flow.flows.*.wavenet.in_layers.0",
    "flow.flows.*.enc.in_layers.1": "flow.flows.*.wavenet.in_layers.1",
    "flow.flows.*.enc.in_layers.2": "flow.flows.*.wavenet.in_layers.2",
    "flow.flows.*.enc.in_layers.3": "flow.flows.*.wavenet.in_layers.3",
    "flow.flows.*.enc.res_skip_layers.0": "flow.flows.*.wavenet.res_skip_layers.0",
    "flow.flows.*.enc.res_skip_layers.1": "flow.flows.*.wavenet.res_skip_layers.1",
    "flow.flows.*.enc.res_skip_layers.2": "flow.flows.*.wavenet.res_skip_layers.2",
    "flow.flows.*.enc.res_skip_layers.3": "flow.flows.*.wavenet.res_skip_layers.3",
    "flow.flows.*.enc.cond_layer": "flow.flows.*.wavenet.cond_layer",  # num_speakers > 1
    "flow.flows.*.post": "flow.flows.*.conv_post",
}
MAPPING_GENERATOR = {
    "dec.conv_pre": "decoder.conv_pre",
    "dec.ups.0": "decoder.upsampler.0",
    "dec.ups.1": "decoder.upsampler.1",
    "dec.ups.2": "decoder.upsampler.2",
    "dec.ups.3": "decoder.upsampler.3",
    "dec.resblocks.*.convs1.0": "decoder.resblocks.*.convs1.0",
    "dec.resblocks.*.convs1.1": "decoder.resblocks.*.convs1.1",
    "dec.resblocks.*.convs1.2": "decoder.resblocks.*.convs1.2",
    "dec.resblocks.*.convs2.0": "decoder.resblocks.*.convs2.0",
    "dec.resblocks.*.convs2.1": "decoder.resblocks.*.convs2.1",
    "dec.resblocks.*.convs2.2": "decoder.resblocks.*.convs2.2",
    "dec.conv_post": "decoder.conv_post",
    "dec.cond": "decoder.cond",  # num_speakers > 1
}
MAPPING_POSTERIOR_ENCODER = {
    "enc_q.pre": "posterior_encoder.conv_pre",
    "enc_q.enc.in_layers.*": "posterior_encoder.wavenet.in_layers.*",
    "enc_q.enc.res_skip_layers.*": "posterior_encoder.wavenet.res_skip_layers.*",
    "enc_q.enc.cond_layer": "posterior_encoder.wavenet.cond_layer",  # num_speakers > 1
    "enc_q.proj": "posterior_encoder.conv_proj",
}
MAPPING = {
    **MAPPING_TEXT_ENCODER,
    **MAPPING_STOCHASTIC_DURATION_PREDICTOR,
    **MAPPING_FLOW,
    **MAPPING_GENERATOR,
    **MAPPING_POSTERIOR_ENCODER,
    "emb_g": "embed_speaker",  # num_speakers > 1
}
TOP_LEVEL_KEYS = []
IGNORE_KEYS = []


def set_recursively(hf_pointer, key, value, full_name, weight_type):
    for attribute in key.split("."):
        hf_pointer = getattr(hf_pointer, attribute)

    if weight_type is not None:
        hf_shape = getattr(hf_pointer, weight_type).shape
    else:
        hf_shape = hf_pointer.shape

    # strip off the kernel dimension at the end (original weights are Conv1d)
    if key.endswith(".k_proj") or key.endswith(".v_proj") or key.endswith(".q_proj") or key.endswith(".out_proj"):
        value = value.squeeze(-1)

    if hf_shape != value.shape:
        raise ValueError(
            f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
            f" {value.shape} for {full_name}"
        )

    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        hf_pointer.weight_g.data = value
    elif weight_type == "weight_v":
        hf_pointer.weight_v.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "running_mean":
        hf_pointer.running_mean.data = value
    elif weight_type == "running_var":
        hf_pointer.running_var.data = value
    elif weight_type == "num_batches_tracked":
        hf_pointer.num_batches_tracked.data = value
    else:
        hf_pointer.data = value

    logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")


def should_ignore(name, ignore_keys):
    for key in ignore_keys:
        if key.endswith(".*"):
            if name.startswith(key[:-1]):
                return True
        elif ".*." in key:
            prefix, suffix = key.split(".*.")
            if prefix in name and suffix in name:
                return True
        elif key in name:
            return True
    return False


def recursively_load_weights(fairseq_dict, hf_model):
    unused_weights = []

    for name, value in fairseq_dict.items():
        if should_ignore(name, IGNORE_KEYS):
            logger.info(f"{name} was ignored")
            continue

        is_used = False
        for key, mapped_key in MAPPING.items():
            if key.endswith(".*"):
                key = key[:-1]
            elif "*" in key:
                prefix, suffix = key.split(".*.")
                if prefix in name and suffix in name:
                    key = suffix

            if key in name:
                is_used = True
                if mapped_key.endswith(".*"):
                    layer_index = name.split(key)[-1].split(".")[0]
                    mapped_key = mapped_key.replace("*", layer_index)
                elif "*" in mapped_key:
                    layer_index = name.split(key)[0].split(".")[-2]

                    # remap the layer index since we removed the Flip layers
                    if "flow.flows" in mapped_key:
                        layer_index = str(int(layer_index) // 2)
                    if "duration_predictor.flows" in mapped_key or "duration_predictor.post_flows" in mapped_key:
                        layer_index = str(int(layer_index) // 2 + 1)

                    mapped_key = mapped_key.replace("*", layer_index)
                if "weight_g" in name:
                    weight_type = "weight_g"
                elif "weight_v" in name:
                    weight_type = "weight_v"
                elif "bias" in name:
                    weight_type = "bias"
                elif "weight" in name:
                    weight_type = "weight"
                elif "running_mean" in name:
                    weight_type = "running_mean"
                elif "running_var" in name:
                    weight_type = "running_var"
                elif "num_batches_tracked" in name:
                    weight_type = "num_batches_tracked"
                else:
                    weight_type = None
                set_recursively(hf_model, mapped_key, value, name, weight_type)
            continue
        if not is_used:
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")


@torch.no_grad()
def convert_checkpoint(
    pytorch_dump_folder_path,
    checkpoint_path=None,
    config_path=None,
    vocab_path=None,
    language=None,
    num_speakers=None,
    sampling_rate=None,
    repo_id=None,
):
    """
    Copy/paste/tweak model's weights to transformers design.
    """
    if config_path is not None:
        config = VitsConfig.from_pretrained(config_path)
    else:
        config = VitsConfig()

    if num_speakers:
        config.num_speakers = num_speakers
        config.speaker_embedding_size = 256

    if sampling_rate:
        config.sampling_rate = sampling_rate

    if checkpoint_path is None:
        logger.info(f"***Converting model: facebook/mms-tts {language}***")

        vocab_path = hf_hub_download(
            repo_id="facebook/mms-tts",
            filename="vocab.txt",
            subfolder=f"models/{language}",
        )
        config_file = hf_hub_download(
            repo_id="facebook/mms-tts",
            filename="config.json",
            subfolder=f"models/{language}",
        )
        checkpoint_path = hf_hub_download(
            repo_id="facebook/mms-tts",
            filename="G_100000.pth",
            subfolder=f"models/{language}",
        )

        with open(config_file, "r") as f:
            data = f.read()
        hps = json.loads(data)

        is_uroman = hps["data"]["training_files"].split(".")[-1] == "uroman"
        if is_uroman:
            logger.warning("For this checkpoint, you should use `uroman` to convert input text before tokenizing it!")
    else:
        logger.info(f"***Converting model: {checkpoint_path}***")
        is_uroman = False

    # original VITS checkpoint
    if vocab_path is None:
        _pad = "_"
        _punctuation = ';:,.!?¡¿—…"«»“” '
        _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
        symbols = _pad + _punctuation + _letters + _letters_ipa
        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        phonemize = True
    else:
        # Save vocab as temporary json file
        symbols = [line.replace("\n", "") for line in open(vocab_path, encoding="utf-8").readlines()]
        symbol_to_id = {s: i for i, s in enumerate(symbols)}
        # MMS-TTS does not use a <pad> token, so we set to the token used to space characters
        _pad = symbols[0]
        phonemize = False

    with tempfile.NamedTemporaryFile() as tf:
        with open(tf.name, "w", encoding="utf-8") as f:
            f.write(json.dumps(symbol_to_id, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        tokenizer = VitsTokenizer(tf.name, language=language, phonemize=phonemize, is_uroman=is_uroman, pad_token=_pad)

    config.vocab_size = len(symbols)
    model = VitsModel(config)

    model.decoder.apply_weight_norm()

    orig_checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
    recursively_load_weights(orig_checkpoint["model"], model)

    model.decoder.remove_weight_norm()

    model.save_pretrained(pytorch_dump_folder_path)
    tokenizer.save_pretrained(pytorch_dump_folder_path)

    if repo_id:
        print("Pushing to the hub...")
        tokenizer.push_to_hub(repo_id)
        model.push_to_hub(repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default=None, type=str, help="Local path to original checkpoint")
    parser.add_argument("--vocab_path", default=None, type=str, help="Path to vocab.txt")
    parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument("--language", default=None, type=str, help="Tokenizer language (three-letter code)")
    parser.add_argument("--num_speakers", default=None, type=int, help="Number of speakers")
    parser.add_argument(
        "--sampling_rate", default=None, type=int, help="Sampling rate on which the model was trained."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
    )

    args = parser.parse_args()
    convert_checkpoint(
        args.pytorch_dump_folder_path,
        args.checkpoint_path,
        args.config_path,
        args.vocab_path,
        args.language,
        args.num_speakers,
        args.sampling_rate,
        args.push_to_hub,
    )
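For reference, when `--checkpoint_path` is omitted the script above downloads the vocabulary, config, and `G_100000.pth` weights for the requested language from the `facebook/mms-tts` repo. A hedged sketch of calling the conversion entry point programmatically (the import path and output folder are placeholders, not part of the packaged API):

```python
# Hypothetical local import of the conversion script shown above.
from convert_original_checkpoint import convert_checkpoint

convert_checkpoint(
    pytorch_dump_folder_path="./vits-mms-eng",  # placeholder output directory
    language="eng",  # three-letter MMS language code, used to pick the Hub subfolder
    repo_id=None,    # set to "user/repo" to push the converted model to the Hub
)
```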
docs/transformers/build/lib/transformers/models/vits/modeling_vits.py
ADDED
@@ -0,0 +1,1493 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch VITS model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Any, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
|
| 26 |
+
from ...activations import ACT2FN
|
| 27 |
+
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 28 |
+
from ...integrations.fsdp import is_fsdp_managed_module
|
| 29 |
+
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
|
| 30 |
+
from ...modeling_outputs import (
|
| 31 |
+
BaseModelOutput,
|
| 32 |
+
ModelOutput,
|
| 33 |
+
)
|
| 34 |
+
from ...modeling_utils import PreTrainedModel
|
| 35 |
+
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
| 36 |
+
from .configuration_vits import VitsConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
logger = logging.get_logger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# General docstring
|
| 43 |
+
_CONFIG_FOR_DOC = "VitsConfig"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
|
| 47 |
+
class VitsModelOutput(ModelOutput):
|
| 48 |
+
"""
|
| 49 |
+
Describes the outputs for the VITS model, with potential hidden states and attentions.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 53 |
+
The final audio waveform predicted by the model.
|
| 54 |
+
sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
|
| 55 |
+
The length in samples of each element in the `waveform` batch.
|
| 56 |
+
spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
|
| 57 |
+
The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi
|
| 58 |
+
GAN decoder model to obtain the final audio waveform.
|
| 59 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 60 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 61 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 62 |
+
|
| 63 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 64 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 65 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 66 |
+
sequence_length)`.
|
| 67 |
+
|
| 68 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 69 |
+
heads.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
waveform: Optional[torch.FloatTensor] = None
|
| 73 |
+
sequence_lengths: Optional[torch.FloatTensor] = None
|
| 74 |
+
spectrogram: Optional[Tuple[torch.FloatTensor]] = None
|
| 75 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 76 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@dataclass
|
| 80 |
+
class VitsTextEncoderOutput(ModelOutput):
|
| 81 |
+
"""
|
| 82 |
+
Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 86 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 87 |
+
prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 88 |
+
The predicted mean values of the prior distribution for the latent text variables.
|
| 89 |
+
prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
| 90 |
+
The predicted log-variance values of the prior distribution for the latent text variables.
|
| 91 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 92 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 93 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 94 |
+
|
| 95 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 96 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 97 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 98 |
+
sequence_length)`.
|
| 99 |
+
|
| 100 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 101 |
+
heads.
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 105 |
+
prior_means: Optional[torch.FloatTensor] = None
|
| 106 |
+
prior_log_variances: Optional[torch.FloatTensor] = None
|
| 107 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 108 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@torch.jit.script
|
| 112 |
+
def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
|
| 113 |
+
in_act = input_a + input_b
|
| 114 |
+
t_act = torch.tanh(in_act[:, :num_channels, :])
|
| 115 |
+
s_act = torch.sigmoid(in_act[:, num_channels:, :])
|
| 116 |
+
acts = t_act * s_act
|
| 117 |
+
return acts
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _unconstrained_rational_quadratic_spline(
|
| 121 |
+
inputs,
|
| 122 |
+
unnormalized_widths,
|
| 123 |
+
    unnormalized_heights,
    unnormalized_derivatives,
    reverse=False,
    tail_bound=5.0,
    min_bin_width=1e-3,
    min_bin_height=1e-3,
    min_derivative=1e-3,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
    `tail_bound`, the transform behaves as an identity function.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        reverse (`bool`, *optional*, defaults to `False`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`, *optional*, defaults to 5):
            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
            transform behaves as an identity function.
        min_bin_width (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`, *optional*, defaults to 1e-3):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.
    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
            applied.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
            limits applied.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    log_abs_det = torch.zeros_like(inputs)
    constant = np.log(np.exp(1 - min_derivative) - 1)

    unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    log_abs_det[outside_interval_mask] = 0.0

    outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        reverse=reverse,
        tail_bound=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    return outputs, log_abs_det

def _rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    reverse,
    tail_bound,
    min_bin_width,
    min_bin_height,
    min_derivative,
):
    """
    This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
    function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.

    Args:
        inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Second half of the hidden-states input to the Vits convolutional flow module.
        unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
            Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
            layer in the convolutional flow module
        reverse (`bool`):
            Whether the model is being run in reverse mode.
        tail_bound (`float`):
            Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
            transform behaves as an identity function.
        min_bin_width (`float`):
            Minimum bin value across the width dimension for the piecewise rational quadratic function.
        min_bin_height (`float`):
            Minimum bin value across the height dimension for the piecewise rational quadratic function.
        min_derivative (`float`):
            Minimum bin value across the derivatives for the piecewise rational quadratic function.
    Returns:
        outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Hidden-states as transformed by the piecewise rational quadratic function.
        log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
            Logarithm of the absolute value of the determinants corresponding to the `outputs`.
    """
    upper_bound = tail_bound
    lower_bound = -tail_bound

    if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
    if min_bin_height * num_bins > 1.0:
        raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")

    widths = nn.functional.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
    cumwidths[..., 0] = lower_bound
    cumwidths[..., -1] = upper_bound
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)

    heights = nn.functional.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
    cumheights[..., 0] = lower_bound
    cumheights[..., -1] = upper_bound
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    bin_locations = cumheights if reverse else cumwidths
    bin_locations[..., -1] += 1e-6
    bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
    bin_idx = bin_idx[..., None]

    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
    if not reverse:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, log_abs_det
    else:
        # find the roots of a quadratic equation
        intermediate2 = inputs - input_cumheights
        intermediate3 = intermediate2 * intermediate1
        a = input_heights * (input_delta - input_derivatives) + intermediate3
        b = input_heights * input_derivatives - intermediate3
        c = -input_delta * intermediate2

        discriminant = b.pow(2) - 4 * a * c
        if not (discriminant >= 0).all():
            raise RuntimeError(f"invalid discriminant {discriminant}")

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + intermediate1 * theta_one_minus_theta
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
        return outputs, -log_abs_det

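# Illustrative usage sketch for the two spline helpers above (kept fully commented so it
# never runs at import time; the shapes and the bin count of 10 are assumptions for
# illustration, not values taken from a VITS checkpoint):
#
#     x = torch.rand(4) * 8 - 4        # values inside the default tail_bound of 5.0
#     w = torch.randn(4, 10)           # unnormalized widths, 10 bins
#     h = torch.randn(4, 10)           # unnormalized heights, 10 bins
#     d = torch.randn(4, 9)            # unnormalized derivatives (padded to 11 internally)
#     y, log_abs_det = _unconstrained_rational_quadratic_spline(x, w, h, d)
#     x_rec, _ = _unconstrained_rational_quadratic_spline(y, w, h, d, reverse=True)
#     # x_rec matches x up to numerical error: the spline is invertible inside the tail
#     # bound and acts as the identity outside it.
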
class VitsWaveNet(torch.nn.Module):
    def __init__(self, config: VitsConfig, num_layers: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(config.wavenet_dropout)

        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if config.speaker_embedding_size != 0:
            cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
            self.cond_layer = weight_norm(cond_layer, name="weight")

        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        outputs = torch.zeros_like(inputs)
        num_channels_tensor = torch.IntTensor([self.hidden_size])

        if global_conditioning is not None:
            global_conditioning = self.cond_layer(global_conditioning)

        for i in range(self.num_layers):
            hidden_states = self.in_layers[i](inputs)

            if global_conditioning is not None:
                cond_offset = i * 2 * self.hidden_size
                global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
            else:
                global_states = torch.zeros_like(hidden_states)

            acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
            acts = self.dropout(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = (inputs + res_acts) * padding_mask
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts

        return outputs * padding_mask

    def remove_weight_norm(self):
        if self.speaker_embedding_size != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)

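# The gated activation inside VitsWaveNet.forward mixes the dilated-conv output with the
# speaker conditioning and splits the channels into a tanh half and a sigmoid half. A
# minimal sketch of that gating, assuming the usual WaveNet formulation for
# `fused_add_tanh_sigmoid_multiply` (the helper is defined earlier in this file):
#
#     in_act = hidden_states + global_states
#     tanh_part = torch.tanh(in_act[:, :num_channels, :])
#     sigmoid_part = torch.sigmoid(in_act[:, num_channels:, :])
#     acts = tanh_part * sigmoid_part
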
class VitsPosteriorEncoder(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.out_channels = config.flow_size

        self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
        self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        inputs = self.conv_pre(inputs) * padding_mask
        inputs = self.wavenet(inputs, padding_mask, global_conditioning)
        stats = self.conv_proj(inputs) * padding_mask
        mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
        sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
        return sampled, mean, log_stddev

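# The posterior encoder draws its latent with the reparameterization trick:
# sampled = mean + eps * exp(log_stddev) with eps ~ N(0, I), so the draw stays stochastic
# while gradients can still flow through `mean` and `log_stddev` during training.
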
# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
class HifiGanResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation[i],
                    padding=self.get_padding(kernel_size, dilation[i]),
                )
                for i in range(len(dilation))
            ]
        )
        self.convs2 = nn.ModuleList(
            [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    dilation=1,
                    padding=self.get_padding(kernel_size, 1),
                )
                for _ in range(len(dilation))
            ]
        )

    def get_padding(self, kernel_size, dilation=1):
        return (kernel_size * dilation - dilation) // 2

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.convs1:
            weight_norm(layer)
        for layer in self.convs2:
            weight_norm(layer)

    def remove_weight_norm(self):
        for layer in self.convs1:
            nn.utils.remove_weight_norm(layer)
        for layer in self.convs2:
            nn.utils.remove_weight_norm(layer)

    def forward(self, hidden_states):
        for conv1, conv2 in zip(self.convs1, self.convs2):
            residual = hidden_states
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv1(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            hidden_states = conv2(hidden_states)
            hidden_states = hidden_states + residual
        return hidden_states

class VitsHifiGan(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.config = config
        self.num_kernels = len(config.resblock_kernel_sizes)
        self.num_upsamples = len(config.upsample_rates)
        self.conv_pre = nn.Conv1d(
            config.flow_size,
            config.upsample_initial_channel,
            kernel_size=7,
            stride=1,
            padding=3,
        )

        self.upsampler = nn.ModuleList()
        for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            self.upsampler.append(
                nn.ConvTranspose1d(
                    config.upsample_initial_channel // (2**i),
                    config.upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=kernel_size,
                    stride=upsample_rate,
                    padding=(kernel_size - upsample_rate) // 2,
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.upsampler)):
            channels = config.upsample_initial_channel // (2 ** (i + 1))
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))

        self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)

        if config.speaker_embedding_size != 0:
            self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.upsampler:
            weight_norm(layer)
        for layer in self.resblocks:
            layer.apply_weight_norm()

    def remove_weight_norm(self):
        for layer in self.upsampler:
            nn.utils.remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()

    def forward(
        self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
    ) -> torch.FloatTensor:
        r"""
        Converts a spectrogram into a speech waveform.

        Args:
            spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
                Tensor containing the spectrograms.
            global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
                Tensor containing speaker embeddings, for multispeaker models.

        Returns:
            `torch.FloatTensor`: Tensor of shape `(batch_size, 1, num_frames)` containing the speech waveform.
        """
        hidden_states = self.conv_pre(spectrogram)

        if global_conditioning is not None:
            hidden_states = hidden_states + self.cond(global_conditioning)

        for i in range(self.num_upsamples):
            hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
            hidden_states = self.upsampler[i](hidden_states)

            res_state = self.resblocks[i * self.num_kernels](hidden_states)
            for j in range(1, self.num_kernels):
                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
            hidden_states = res_state / self.num_kernels

        hidden_states = nn.functional.leaky_relu(hidden_states)
        hidden_states = self.conv_post(hidden_states)
        waveform = torch.tanh(hidden_states)
        return waveform

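# The waveform length follows directly from the transposed convolutions above: each stage
# multiplies the time axis by its `upsample_rate`, so
# num_frames ~= sequence_length * prod(config.upsample_rates). As an illustration (the
# rates here are an assumption, not read from any particular checkpoint), rates of
# (8, 8, 2, 2) turn one spectrogram frame into 256 waveform samples.
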
class VitsResidualCouplingLayer(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half) * padding_mask
        hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
        mean = self.conv_post(hidden_states) * padding_mask
        log_stddev = torch.zeros_like(mean)

        if not reverse:
            second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            log_determinant = torch.sum(log_stddev, [1, 2])
            return outputs, log_determinant
        else:
            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            return outputs, None


class VitsResidualCouplingBlock(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                inputs, _ = flow(inputs, padding_mask, global_conditioning)
                inputs = torch.flip(inputs, [1])
        else:
            for flow in reversed(self.flows):
                inputs = torch.flip(inputs, [1])
                inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
        return inputs

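# Each VitsResidualCouplingLayer is invertible by construction: the first half of the
# channels passes through unchanged and parameterizes `mean`, so the reverse branch can
# recompute exactly the same `mean` and undo the forward update of the second half.
# Flipping the channel order between layers ensures every channel is eventually
# transformed at some depth of VitsResidualCouplingBlock.
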
class VitsDilatedDepthSeparableConv(nn.Module):
    def __init__(self, config: VitsConfig, dropout_rate=0.0):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        channels = config.hidden_size
        self.num_layers = config.depth_separable_num_layers

        self.dropout = nn.Dropout(dropout_rate)
        self.convs_dilated = nn.ModuleList()
        self.convs_pointwise = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(self.num_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_dilated.append(
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    groups=channels,
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(nn.LayerNorm(channels))
            self.norms_2.append(nn.LayerNorm(channels))

    def forward(self, inputs, padding_mask, global_conditioning=None):
        if global_conditioning is not None:
            inputs = inputs + global_conditioning

        for i in range(self.num_layers):
            hidden_states = self.convs_dilated[i](inputs * padding_mask)
            hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.convs_pointwise[i](hidden_states)
            hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
            hidden_states = nn.functional.gelu(hidden_states)
            hidden_states = self.dropout(hidden_states)
            inputs = inputs + hidden_states

        return inputs * padding_mask

class VitsConvFlow(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.filter_channels = config.hidden_size
        self.half_channels = config.depth_separable_channels // 2
        self.num_bins = config.duration_predictor_flow_bins
        self.tail_bound = config.duration_predictor_tail_bound

        self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(config)
        self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)

        hidden_states = self.conv_pre(first_half)
        hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
        hidden_states = self.conv_proj(hidden_states) * padding_mask

        batch_size, channels, length = first_half.shape
        hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)

        unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]

        second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
            second_half,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            reverse=reverse,
            tail_bound=self.tail_bound,
        )

        outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
        if not reverse:
            log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            return outputs, None


class VitsElementwiseAffine(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.channels = config.depth_separable_channels
        self.translate = nn.Parameter(torch.zeros(self.channels, 1))
        self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        if not reverse:
            outputs = self.translate + torch.exp(self.log_scale) * inputs
            outputs = outputs * padding_mask
            log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
            return outputs, log_determinant
        else:
            outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
            return outputs, None

class VitsStochasticDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.speaker_embedding_size
        filter_channels = config.hidden_size

        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        if embed_dim != 0:
            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)

        self.flows = nn.ModuleList()
        self.flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.flows.append(VitsConvFlow(config))

        self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        self.post_flows = nn.ModuleList()
        self.post_flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.post_flows.append(VitsConvFlow(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
        inputs = torch.detach(inputs)
        inputs = self.conv_pre(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_dds(inputs, padding_mask)
        inputs = self.conv_proj(inputs) * padding_mask

        if not reverse:
            hidden_states = self.post_conv_pre(durations)
            hidden_states = self.post_conv_dds(hidden_states, padding_mask)
            hidden_states = self.post_conv_proj(hidden_states) * padding_mask

            random_posterior = (
                torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * padding_mask
            )
            log_determinant_posterior_sum = 0
            latents_posterior = random_posterior
            for flow in self.post_flows:
                latents_posterior, log_determinant = flow(
                    latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
                )
                latents_posterior = torch.flip(latents_posterior, [1])
                log_determinant_posterior_sum += log_determinant

            first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)

            log_determinant_posterior_sum += torch.sum(
                (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
                - log_determinant_posterior_sum
            )

            first_half = (durations - torch.sigmoid(first_half)) * padding_mask
            first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
            log_determinant_sum = torch.sum(-first_half, [1, 2])

            latents = torch.cat([first_half, second_half], dim=1)
            for flow in self.flows:
                latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
                latents = torch.flip(latents, [1])
                log_determinant_sum += log_determinant

            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow

            latents = (
                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * noise_scale
            )
            for flow in flows:
                latents = torch.flip(latents, [1])
                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)

            log_duration, _ = torch.split(latents, [1, 1], dim=1)
            return log_duration

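# In the reverse (inference) branch above, random noise is pushed backwards through the
# flows to sample a log-duration per token. The usual downstream step turns this into
# integer frame counts with something along the lines of
# `torch.ceil(torch.exp(log_duration) * length_scale)`; that conversion is only a sketch
# of the typical usage and is not code from this class.
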
class VitsDurationPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        kernel_size = config.duration_predictor_kernel_size
        filter_channels = config.duration_predictor_filter_channels

        self.dropout = nn.Dropout(config.duration_predictor_dropout)
        self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if config.speaker_embedding_size != 0:
            self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        inputs = torch.detach(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_1(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.conv_2(inputs * padding_mask)
        inputs = torch.relu(inputs)
        inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
        inputs = self.dropout(inputs)

        inputs = self.proj(inputs * padding_mask)
        return inputs * padding_mask

class VitsAttention(nn.Module):
    """Multi-headed attention with relative positional representation."""

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.window_size = config.window_size

        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
                f" and `num_attention_heads`: {self.num_heads})."
            )

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)

        if self.window_size:
            self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
            self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling

        # self_attention
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if self.window_size is not None:
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
            relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
            rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
            attn_weights += rel_pos_bias

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        if self.window_size is not None:
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
            relative_weights = self._absolute_position_to_relative_position(attn_probs)
            rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
            attn_output += rel_pos_bias

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        if pad_length > 0:
            relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])

        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        return relative_embeddings[:, slice_start_position:slice_end_position]

    def _relative_position_to_absolute_position(self, x):
        batch_heads, length, _ = x.size()

        # Concat columns of pad to shift from relative to absolute indexing.
        x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch_heads, length * 2 * length])
        x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
        x_final = x_final[:, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch_heads, length, _ = x.size()

        # Pad along column
        x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
        x_flat = x.view([batch_heads, length * (2 * length - 1)])

        # Add 0's in the beginning that will skew the elements after reshape
        x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
        x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
        return x_final

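# The two index-manipulation helpers above implement the standard "skewing" trick for
# relative attention: padding one extra column and reshaping converts a
# (length, 2 * length - 1) relative-position matrix into a (length, length) absolute
# layout, and back, without building an explicit gather index.
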
class VitsFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
        self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
        self.dropout = nn.Dropout(config.activation_dropout)

        if isinstance(config.hidden_act, str):
            self.act_fn = ACT2FN[config.hidden_act]
        else:
            self.act_fn = config.hidden_act

        if config.ffn_kernel_size > 1:
            pad_left = (config.ffn_kernel_size - 1) // 2
            pad_right = config.ffn_kernel_size // 2
            self.padding = [pad_left, pad_right, 0, 0, 0, 0]
        else:
            self.padding = None

    def forward(self, hidden_states, padding_mask):
        hidden_states = hidden_states.permute(0, 2, 1)
        padding_mask = padding_mask.permute(0, 2, 1)

        hidden_states = hidden_states * padding_mask
        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)

        hidden_states = self.conv_1(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states * padding_mask
        if self.padding is not None:
            hidden_states = nn.functional.pad(hidden_states, self.padding)

        hidden_states = self.conv_2(hidden_states)
        hidden_states = hidden_states * padding_mask

        hidden_states = hidden_states.permute(0, 2, 1)
        return hidden_states

class VitsEncoderLayer(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.attention = VitsAttention(config)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = VitsFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        padding_mask: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        residual = hidden_states
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layer_norm(residual + hidden_states)

        residual = hidden_states
        hidden_states = self.feed_forward(hidden_states, padding_mask)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.final_layer_norm(residual + hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

class VitsEncoder(nn.Module):
    def __init__(self, config: VitsConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.layerdrop = config.layerdrop

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        padding_mask: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        hidden_states = hidden_states * padding_mask

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for encoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = np.random.uniform(0, 1)

            skip_the_layer = self.training and (dropout_probability < self.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        padding_mask,
                        attention_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        padding_mask=padding_mask,
                        output_attentions=output_attentions,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        hidden_states = hidden_states * padding_mask

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

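# During training, VitsEncoder applies LayerDrop: each layer is skipped with probability
# `config.layerdrop`. When DeepSpeed ZeRO-3 or FSDP manages the module (`synced_gpus`),
# the layer is still executed on every rank so that all processes stay in sync.
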
class VitsTextEncoder(nn.Module):
    """
    Transformer encoder that uses relative positional representation instead of absolute positional encoding.
    """

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.encoder = VitsEncoder(config)
        self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.Tensor,
        padding_mask: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]:
        hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)

        encoder_outputs = self.encoder(
            hidden_states=hidden_states,
            padding_mask=padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state

        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)

        if not return_dict:
            outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
            return outputs

        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            prior_log_variances=prior_log_variances,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

class VitsPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

VITS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`VitsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


VITS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The complete VITS model, for text-to-speech synthesis.",
    VITS_START_DOCSTRING,
)
class VitsModel(VitsPreTrainedModel):
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        # This is used only for training.
        self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.text_encoder

    @add_start_docstrings_to_model_forward(VITS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VitsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[Any], VitsModelOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
            Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation.

        Returns:

        Example:

        ```python
        >>> from transformers import VitsTokenizer, VitsModel, set_seed
        >>> import torch

        >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
        >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

        >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

        >>> set_seed(555)  # make deterministic

        >>> with torch.no_grad():
        ...     outputs = model(inputs["input_ids"])
        >>> outputs.waveform.shape
        torch.Size([1, 45824])
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
        else:
            speaker_embeddings = None

        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
|
| 1440 |
+
hidden_states,
|
| 1441 |
+
input_padding_mask,
|
| 1442 |
+
speaker_embeddings,
|
| 1443 |
+
reverse=True,
|
| 1444 |
+
noise_scale=self.noise_scale_duration,
|
| 1445 |
+
)
|
| 1446 |
+
else:
|
| 1447 |
+
log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
|
| 1448 |
+
|
| 1449 |
+
length_scale = 1.0 / self.speaking_rate
|
| 1450 |
+
duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
|
| 1451 |
+
predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
|
| 1452 |
+
|
| 1453 |
+
# Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
|
| 1454 |
+
indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
|
| 1455 |
+
output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
|
| 1456 |
+
output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
|
| 1457 |
+
|
| 1458 |
+
# Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
|
| 1459 |
+
attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
|
| 1460 |
+
batch_size, _, output_length, input_length = attn_mask.shape
|
| 1461 |
+
cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
|
| 1462 |
+
indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
|
| 1463 |
+
valid_indices = indices.unsqueeze(0) < cum_duration
|
| 1464 |
+
valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
|
| 1465 |
+
padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
|
| 1466 |
+
attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
|
| 1467 |
+
|
| 1468 |
+
# Expand prior distribution
|
| 1469 |
+
prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
|
| 1470 |
+
prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
|
| 1471 |
+
|
| 1472 |
+
prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
|
| 1473 |
+
latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
|
| 1474 |
+
|
| 1475 |
+
spectrogram = latents * output_padding_mask
|
| 1476 |
+
waveform = self.decoder(spectrogram, speaker_embeddings)
|
| 1477 |
+
waveform = waveform.squeeze(1)
|
| 1478 |
+
sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
|
| 1479 |
+
|
| 1480 |
+
if not return_dict:
|
| 1481 |
+
outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
|
| 1482 |
+
return outputs
|
| 1483 |
+
|
| 1484 |
+
return VitsModelOutput(
|
| 1485 |
+
waveform=waveform,
|
| 1486 |
+
sequence_lengths=sequence_lengths,
|
| 1487 |
+
spectrogram=spectrogram,
|
| 1488 |
+
hidden_states=text_encoder_output.hidden_states,
|
| 1489 |
+
attentions=text_encoder_output.attentions,
|
| 1490 |
+
)
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
__all__ = ["VitsModel", "VitsPreTrainedModel"]
|
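The forward pass above exposes a few knobs (`speaking_rate`, `noise_scale`, `noise_scale_duration`) that are plain attributes on the instantiated model. A minimal usage sketch, assuming the `facebook/mms-tts-eng` checkpoint from the docstring example and that `scipy` is installed; the filename and chosen values are illustrative only:

```python
# Illustrative sketch, not part of the library source.
import scipy.io.wavfile
import torch

from transformers import VitsModel, VitsTokenizer, set_seed

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model = VitsModel.from_pretrained("facebook/mms-tts-eng")

model.speaking_rate = 1.2  # durations are divided by this, so >1.0 speaks faster
model.noise_scale = 0.6    # smaller values -> less variation in the sampled prior latents

set_seed(555)
inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.waveform is (batch, num_samples); the sampling rate comes from the model config.
scipy.io.wavfile.write("speech.wav", rate=model.config.sampling_rate, data=outputs.waveform[0].numpy())
```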
docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py
ADDED
@@ -0,0 +1,246 @@
# coding=utf-8
# Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for VITS."""

import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

from ...tokenization_utils import PreTrainedTokenizer
from ...utils import is_phonemizer_available, is_uroman_available, logging


if is_phonemizer_available():
    import phonemizer

if is_uroman_available():
    import uroman as ur

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}


def has_non_roman_characters(input_string):
    # Find any character outside the ASCII range
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")

    # Search the input string for non-Roman characters
    match = non_roman_pattern.search(input_string)
    has_non_roman = match is not None
    return has_non_roman


class VitsTokenizer(PreTrainedTokenizer):
    """
    Construct a VITS tokenizer. Also supports MMS-TTS.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        language (`str`, *optional*):
            Language identifier.
        add_blank (`bool`, *optional*, defaults to `True`):
            Whether to insert token id 0 in between the other tokens.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the input text by removing all casing and punctuation.
        phonemize (`bool`, *optional*, defaults to `True`):
            Whether to convert the input text into phonemes.
        is_uroman (`bool`, *optional*, defaults to `False`):
            Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize

        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, respecting any special token ids that may be part or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text

    def _preprocess_char(self, text):
        """Special treatment of characters in certain languages"""
        if self.language == "ron":
            text = text.replace("ț", "ţ")
        return text

    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize.
            normalize (`bool`, *optional*, defaults to `None`):
                Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
                trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
                text consists only of lower-case characters.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to use for the tokenization.

        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
        """
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            # normalise for casing
            text = self.normalize_text(text)

        filtered_text = self._preprocess_char(text)

        if has_non_roman_characters(filtered_text) and self.is_uroman:
            if not is_uroman_available():
                logger.warning(
                    "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
                    "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman`. "
                    "Note `uroman` requires python version >= 3.10. "
                    "Otherwise, apply the Romanizer manually as per the instructions: https://github.com/isi-nlp/uroman"
                )
            else:
                uroman = ur.Uroman()
                filtered_text = uroman.romanize_string(filtered_text)

        if self.phonemize:
            if not is_phonemizer_available():
                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")

            filtered_text = phonemizer.phonemize(
                filtered_text,
                language="en-us",
                backend="espeak",
                strip=True,
                preserve_punctuation=True,
                with_stress=True,
            )
            filtered_text = re.sub(r"\s+", " ", filtered_text)
        elif normalize:
            # strip any chars outside of the vocab (punctuation)
            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()

        return filtered_text, kwargs

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
        tokens = list(text)

        if self.add_blank:
            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
            interspersed[1::2] = tokens
            tokens = interspersed

        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        if self.add_blank and len(tokens) > 1:
            tokens = tokens[1::2]
        return "".join(tokens)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.decoder.get(index)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        return (vocab_file,)


__all__ = ["VitsTokenizer"]
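A short hedged sketch of the `add_blank` behaviour implemented in `_tokenize` above: with the default `add_blank=True`, the blank token (id 0) is interspersed around every character, so n processed characters become 2n + 1 input ids, and `convert_tokens_to_string` drops the blanks again. Assuming the `facebook/mms-tts-eng` checkpoint and that `phonemizer`/espeak are installed:

```python
# Illustrative sketch, not part of the library source.
from transformers import VitsTokenizer

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")

ids = tokenizer("hello")["input_ids"]
print(len(ids) % 2 == 1)  # True: n characters -> 2 * n + 1 ids once blanks are interspersed
print(set(ids[::2]))      # {0}: every other id is the blank token

tokens = tokenizer.convert_ids_to_tokens(ids)
print(tokenizer.convert_tokens_to_string(tokens))  # blanks are stripped back out
```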
docs/transformers/build/lib/transformers/models/vivit/__init__.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vivit import *
    from .image_processing_vivit import *
    from .modeling_vivit import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py
ADDED
@@ -0,0 +1,119 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViViT model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class VivitConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VivitModel`]. It is used to instantiate a ViViT
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ViViT
    [google/vivit-b-16x2-kinetics400](https://huggingface.co/google/vivit-b-16x2-kinetics400) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        num_frames (`int`, *optional*, defaults to 32):
            The number of frames in each video.
        tubelet_size (`List[int]`, *optional*, defaults to `[2, 16, 16]`):
            The size (resolution) of each tubelet.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_fast"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_fast"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.

    Example:

    ```python
    >>> from transformers import VivitConfig, VivitModel

    >>> # Initializing a ViViT google/vivit-b-16x2-kinetics400 style configuration
    >>> configuration = VivitConfig()

    >>> # Initializing a model (with random weights) from the google/vivit-b-16x2-kinetics400 style configuration
    >>> model = VivitModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vivit"

    def __init__(
        self,
        image_size=224,
        num_frames=32,
        tubelet_size=[2, 16, 16],
        num_channels=3,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu_fast",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-06,
        qkv_bias=True,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        self.image_size = image_size
        self.num_frames = num_frames
        self.tubelet_size = tubelet_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias

        super().__init__(**kwargs)


__all__ = ["VivitConfig"]
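For orientation, a derived sketch (not in the source file): the defaults above determine how many tubelet tokens the ViViT encoder sees, since the video is split into non-overlapping tubelets of `tubelet_size` along (frames, height, width):

```python
# Derived from the default values above; purely illustrative.
from transformers import VivitConfig

config = VivitConfig()
t, h, w = config.tubelet_size
num_tubelets = (config.num_frames // t) * (config.image_size // h) * (config.image_size // w)
print(num_tubelets)  # 16 * 14 * 14 = 3136 patch tokens (a CLS token is prepended on top of these)
```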
docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py
ADDED
@@ -0,0 +1,231 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL:
https://github.com/google-research/scenic/tree/main/scenic/projects/vivit
"""

import argparse
import json
import os.path
from collections import OrderedDict

import numpy as np
import requests
import torch
from flax.training.checkpoints import restore_checkpoint
from huggingface_hub import hf_hub_download

from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor
from transformers.image_utils import PILImageResampling


def download_checkpoint(path):
    url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint"

    with open(path, "wb") as f:
        with requests.get(url, stream=True) as req:
            for chunk in req.iter_content(chunk_size=2048):
                f.write(chunk)


def get_vivit_config() -> VivitConfig:
    config = VivitConfig()

    config.num_labels = 400
    repo_id = "huggingface/label-files"
    filename = "kinetics400-id2label.json"

    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    return config


# We will verify our results on a video of eating spaghetti
# Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117,
# 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174]
def prepare_video():
    file = hf_hub_download(
        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset"
    )
    video = np.load(file)
    return list(video)


def transform_attention(current: np.ndarray):
    if np.ndim(current) == 2:
        return transform_attention_bias(current)

    elif np.ndim(current) == 3:
        return transform_attention_kernel(current)

    else:
        raise Exception(f"Invalid number of dimensions: {np.ndim(current)}")


def transform_attention_bias(current: np.ndarray):
    return current.flatten()


def transform_attention_kernel(current: np.ndarray):
    return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T


def transform_attention_output_weight(current: np.ndarray):
    return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T


def transform_state_encoder_block(state_dict, i):
    state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"]

    prefix = f"encoder.layer.{i}."
    new_state = {
        prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"],
        prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]),
        prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"],
        prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]),
        prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"],
        prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"],
        prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"],
        prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"],
        prefix + "attention.attention.query.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["query"]["bias"]
        ),
        prefix + "attention.attention.query.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["query"]["kernel"]
        ),
        prefix + "attention.attention.key.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["key"]["bias"]
        ),
        prefix + "attention.attention.key.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["key"]["kernel"]
        ),
        prefix + "attention.attention.value.bias": transform_attention(
            state["MultiHeadDotProductAttention_0"]["value"]["bias"]
        ),
        prefix + "attention.attention.value.weight": transform_attention(
            state["MultiHeadDotProductAttention_0"]["value"]["kernel"]
        ),
        prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"],
        prefix + "attention.output.dense.weight": transform_attention_output_weight(
            state["MultiHeadDotProductAttention_0"]["out"]["kernel"]
        ),
    }

    return new_state


def get_n_layers(state_dict):
    return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"].keys()])


def transform_state(state_dict, classification_head=False):
    transformer_layers = get_n_layers(state_dict)

    new_state = OrderedDict()

    new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"]
    new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"]

    new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose(
        state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2)
    )
    new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"]

    new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"]
    new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][
        "pos_embedding"
    ]

    for i in range(transformer_layers):
        new_state.update(transform_state_encoder_block(state_dict, i))

    if classification_head:
        new_state = {"vivit." + k: v for k, v in new_state.items()}
        new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"])
        new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"])

    return {k: torch.tensor(v) for k, v in new_state.items()}


# checks that image processor settings are the same as in the original implementation
# original: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py
# dataset specific config:
# https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py
def get_processor() -> VivitImageProcessor:
    extractor = VivitImageProcessor()

    assert extractor.do_resize is True
    assert extractor.size == {"shortest_edge": 256}
    assert extractor.do_center_crop is True
    assert extractor.crop_size == {"width": 224, "height": 224}
    assert extractor.resample == PILImageResampling.BILINEAR

    # here: https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py
    # one can see that add_image has default values for normalization_mean and normalization_std set to 0 and 1
    # which effectively means no normalization (and ViViT does not overwrite those when calling this func)
    assert extractor.do_normalize is False
    assert extractor.do_rescale is True
    assert extractor.rescale_factor == 1 / 255

    # zero-centering = True in original implementation
    assert extractor.do_zero_centering is True

    return extractor


def convert(output_path: str):
    flax_model_path = "checkpoint"

    if not os.path.exists(flax_model_path):
        download_checkpoint(flax_model_path)

    state_dict = restore_checkpoint(flax_model_path, None)
    new_state = transform_state(state_dict, classification_head=True)

    config = get_vivit_config()

    assert config.image_size == 224
    assert config.num_frames == 32

    model = VivitForVideoClassification(config)
    model.load_state_dict(new_state)
    model.eval()

    extractor = get_processor()

    video = prepare_video()
    inputs = extractor(video, return_tensors="pt")

    outputs = model(**inputs)

    expected_shape = torch.Size([1, 400])
    expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658])

    assert outputs.logits.shape == expected_shape
    assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5]

    model.save_pretrained(output_path)
    extractor.save_pretrained(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_model_name", "-o", type=str, help="Output path for the converted HuggingFace model")

    args = parser.parse_args()
    convert(args.output_model_name)
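A hedged sketch of the weight-layout change `transform_attention_kernel` performs above: Flax stores each attention projection as a `(hidden, heads, head_dim)` kernel, which is collapsed and transposed into the `(out_features, in_features)` matrix a PyTorch `nn.Linear` expects. The shapes below are hypothetical (ViViT-Base-like), used only to illustrate the reshape:

```python
# Illustrative sketch mirroring transform_attention_kernel; shapes are hypothetical.
import numpy as np

hidden, heads, head_dim = 768, 12, 64
flax_kernel = np.random.randn(hidden, heads, head_dim)  # Flax layout: (hidden, heads, head_dim)

# Collapse the head axes, then transpose -> (heads * head_dim, hidden), i.e. PyTorch (out, in).
torch_weight = np.reshape(flax_kernel, (hidden, heads * head_dim)).T
print(torch_weight.shape)  # (768, 768)
```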
docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py
ADDED
@@ -0,0 +1,407 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Image processor class for Vivit."""
|
| 16 |
+
|
| 17 |
+
from typing import Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
from transformers.utils import is_vision_available
|
| 22 |
+
from transformers.utils.generic import TensorType
|
| 23 |
+
|
| 24 |
+
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
| 25 |
+
from ...image_transforms import (
|
| 26 |
+
get_resize_output_image_size,
|
| 27 |
+
rescale,
|
| 28 |
+
resize,
|
| 29 |
+
to_channel_dimension_format,
|
| 30 |
+
)
|
| 31 |
+
from ...image_utils import (
|
| 32 |
+
IMAGENET_STANDARD_MEAN,
|
| 33 |
+
IMAGENET_STANDARD_STD,
|
| 34 |
+
ChannelDimension,
|
| 35 |
+
ImageInput,
|
| 36 |
+
PILImageResampling,
|
| 37 |
+
infer_channel_dimension_format,
|
| 38 |
+
is_scaled_image,
|
| 39 |
+
is_valid_image,
|
| 40 |
+
to_numpy_array,
|
| 41 |
+
valid_images,
|
| 42 |
+
validate_preprocess_arguments,
|
| 43 |
+
)
|
| 44 |
+
from ...utils import filter_out_non_signature_kwargs, logging
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if is_vision_available():
|
| 48 |
+
import PIL
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def make_batched(videos) -> List[List[ImageInput]]:
|
| 54 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
| 55 |
+
return videos
|
| 56 |
+
|
| 57 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
| 58 |
+
return [videos]
|
| 59 |
+
|
| 60 |
+
elif is_valid_image(videos):
|
| 61 |
+
return [[videos]]
|
| 62 |
+
|
| 63 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class VivitImageProcessor(BaseImageProcessor):
|
| 67 |
+
r"""
|
| 68 |
+
Constructs a Vivit image processor.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 72 |
+
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
| 73 |
+
`do_resize` parameter in the `preprocess` method.
|
| 74 |
+
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`):
|
| 75 |
+
Size of the output image after resizing. The shortest edge of the image will be resized to
|
| 76 |
+
`size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overriden by
|
| 77 |
+
`size` in the `preprocess` method.
|
| 78 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
| 79 |
+
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
|
| 80 |
+
`preprocess` method.
|
| 81 |
+
do_center_crop (`bool`, *optional*, defaults to `True`):
|
| 82 |
+
Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
|
| 83 |
+
parameter in the `preprocess` method.
|
| 84 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
| 85 |
+
Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
|
| 86 |
+
`preprocess` method.
|
| 87 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
| 88 |
+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
|
| 89 |
+
parameter in the `preprocess` method.
|
| 90 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/127.5`):
|
| 91 |
+
Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
|
| 92 |
+
in the `preprocess` method.
|
| 93 |
+
offset (`bool`, *optional*, defaults to `True`):
|
| 94 |
+
Whether to scale the image in both negative and positive directions. Can be overriden by the `offset` in
|
| 95 |
+
the `preprocess` method.
|
| 96 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 97 |
+
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
|
| 98 |
+
method.
|
| 99 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
|
| 100 |
+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
|
| 101 |
+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
| 102 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
|
| 103 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
|
| 104 |
+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
| 105 |
+
"""
|
| 106 |
+
|
| 107 |
+
model_input_names = ["pixel_values"]
|
| 108 |
+
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
do_resize: bool = True,
|
| 112 |
+
size: Dict[str, int] = None,
|
| 113 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 114 |
+
do_center_crop: bool = True,
|
| 115 |
+
crop_size: Dict[str, int] = None,
|
| 116 |
+
do_rescale: bool = True,
|
| 117 |
+
rescale_factor: Union[int, float] = 1 / 127.5,
|
| 118 |
+
offset: bool = True,
|
| 119 |
+
do_normalize: bool = True,
|
| 120 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 121 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 122 |
+
**kwargs,
|
| 123 |
+
) -> None:
|
| 124 |
+
super().__init__(**kwargs)
|
| 125 |
+
size = size if size is not None else {"shortest_edge": 256}
|
| 126 |
+
size = get_size_dict(size, default_to_square=False)
|
| 127 |
+
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
|
| 128 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
| 129 |
+
|
| 130 |
+
self.do_resize = do_resize
|
| 131 |
+
self.size = size
|
| 132 |
+
self.do_center_crop = do_center_crop
|
| 133 |
+
self.crop_size = crop_size
|
| 134 |
+
self.resample = resample
|
| 135 |
+
self.do_rescale = do_rescale
|
| 136 |
+
self.rescale_factor = rescale_factor
|
| 137 |
+
self.offset = offset
|
| 138 |
+
self.do_normalize = do_normalize
|
| 139 |
+
self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
|
| 140 |
+
self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
|
| 141 |
+
|
| 142 |
+
def resize(
|
| 143 |
+
self,
|
| 144 |
+
image: np.ndarray,
|
| 145 |
+
size: Dict[str, int],
|
| 146 |
+
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
| 147 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 148 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 149 |
+
**kwargs,
|
| 150 |
+
) -> np.ndarray:
|
| 151 |
+
"""
|
| 152 |
+
Resize an image.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
image (`np.ndarray`):
|
| 156 |
+
Image to resize.
|
| 157 |
+
size (`Dict[str, int]`):
|
| 158 |
+
Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
|
| 159 |
+
have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
|
| 160 |
+
shortest edge of length `s` while keeping the aspect ratio of the original image.
|
| 161 |
+
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
|
| 162 |
+
Resampling filter to use when resiizing the image.
|
| 163 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 164 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 165 |
+
input_data_format (`str` or `ChannelDimension`, *optional*):
|
| 166 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 167 |
+
"""
|
| 168 |
+
size = get_size_dict(size, default_to_square=False)
|
| 169 |
+
if "shortest_edge" in size:
|
| 170 |
+
output_size = get_resize_output_image_size(
|
| 171 |
+
image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
|
| 172 |
+
)
|
| 173 |
+
elif "height" in size and "width" in size:
|
| 174 |
+
output_size = (size["height"], size["width"])
|
| 175 |
+
else:
|
| 176 |
+
raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
|
| 177 |
+
return resize(
|
| 178 |
+
image,
|
| 179 |
+
size=output_size,
|
| 180 |
+
resample=resample,
|
| 181 |
+
data_format=data_format,
|
| 182 |
+
input_data_format=input_data_format,
|
| 183 |
+
**kwargs,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
|
| 187 |
+
def rescale(
|
| 188 |
+
self,
|
| 189 |
+
image: np.ndarray,
|
| 190 |
+
scale: Union[int, float],
|
| 191 |
+
offset: bool = True,
|
| 192 |
+
data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 193 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 194 |
+
**kwargs,
|
| 195 |
+
):
|
| 196 |
+
"""
|
| 197 |
+
Rescale an image by a scale factor.
|
| 198 |
+
|
| 199 |
+
If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is
|
| 200 |
+
1/127.5, the image is rescaled between [-1, 1].
|
| 201 |
+
image = image * scale - 1
|
| 202 |
+
|
| 203 |
+
If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1].
|
| 204 |
+
image = image * scale
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
image (`np.ndarray`):
|
| 208 |
+
Image to rescale.
|
| 209 |
+
scale (`int` or `float`):
|
| 210 |
+
Scale to apply to the image.
|
| 211 |
+
offset (`bool`, *optional*):
|
| 212 |
+
Whether to scale the image in both negative and positive directions.
|
| 213 |
+
data_format (`str` or `ChannelDimension`, *optional*):
|
| 214 |
+
The channel dimension format of the image. If not provided, it will be the same as the input image.
|
| 215 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 216 |
+
The channel dimension format of the input image. If not provided, it will be inferred.
|
| 217 |
+
"""
|
| 218 |
+
rescaled_image = rescale(
|
| 219 |
+
image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
if offset:
|
| 223 |
+
rescaled_image = rescaled_image - 1
|
| 224 |
+
|
| 225 |
+
return rescaled_image
|
| 226 |
+
|
| 227 |
+
def _preprocess_image(
|
| 228 |
+
self,
|
| 229 |
+
image: ImageInput,
|
| 230 |
+
do_resize: Optional[bool] = None,
|
| 231 |
+
size: Dict[str, int] = None,
|
| 232 |
+
resample: PILImageResampling = None,
|
| 233 |
+
do_center_crop: Optional[bool] = None,
|
| 234 |
+
crop_size: Dict[str, int] = None,
|
| 235 |
+
do_rescale: Optional[bool] = None,
|
| 236 |
+
rescale_factor: Optional[float] = None,
|
| 237 |
+
offset: Optional[bool] = None,
|
| 238 |
+
do_normalize: Optional[bool] = None,
|
| 239 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 240 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 241 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
| 242 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 243 |
+
) -> np.ndarray:
|
| 244 |
+
"""Preprocesses a single image."""
|
| 245 |
+
|
| 246 |
+
validate_preprocess_arguments(
|
| 247 |
+
do_rescale=do_rescale,
|
| 248 |
+
rescale_factor=rescale_factor,
|
| 249 |
+
do_normalize=do_normalize,
|
| 250 |
+
image_mean=image_mean,
|
| 251 |
+
image_std=image_std,
|
| 252 |
+
do_center_crop=do_center_crop,
|
| 253 |
+
crop_size=crop_size,
|
| 254 |
+
do_resize=do_resize,
|
| 255 |
+
size=size,
|
| 256 |
+
resample=resample,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if offset and not do_rescale:
|
| 260 |
+
raise ValueError("For offset, do_rescale must also be set to True.")
|
| 261 |
+
|
| 262 |
+
# All transformations expect numpy arrays.
|
| 263 |
+
image = to_numpy_array(image)
|
| 264 |
+
|
| 265 |
+
if do_rescale and is_scaled_image(image):
|
| 266 |
+
logger.warning_once(
|
| 267 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
| 268 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if input_data_format is None:
|
| 272 |
+
input_data_format = infer_channel_dimension_format(image)
|
| 273 |
+
|
| 274 |
+
if do_resize:
|
| 275 |
+
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
| 276 |
+
|
| 277 |
+
if do_center_crop:
|
| 278 |
+
image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
|
| 279 |
+
|
| 280 |
+
if do_rescale:
|
| 281 |
+
image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)
|
| 282 |
+
|
| 283 |
+
if do_normalize:
|
| 284 |
+
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
| 285 |
+
|
| 286 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
| 287 |
+
return image
|
| 288 |
+
|
| 289 |
+
@filter_out_non_signature_kwargs()
|
| 290 |
+
def preprocess(
|
| 291 |
+
self,
|
| 292 |
+
videos: ImageInput,
|
| 293 |
+
do_resize: Optional[bool] = None,
|
| 294 |
+
size: Dict[str, int] = None,
|
| 295 |
+
resample: PILImageResampling = None,
|
| 296 |
+
do_center_crop: Optional[bool] = None,
|
| 297 |
+
crop_size: Dict[str, int] = None,
|
| 298 |
+
do_rescale: Optional[bool] = None,
|
| 299 |
+
rescale_factor: Optional[float] = None,
|
| 300 |
+
offset: Optional[bool] = None,
|
| 301 |
+
do_normalize: Optional[bool] = None,
|
| 302 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
| 303 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
| 304 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 305 |
+
data_format: ChannelDimension = ChannelDimension.FIRST,
|
| 306 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
| 307 |
+
) -> PIL.Image.Image:
|
| 308 |
+
"""
|
| 309 |
+
Preprocess an image or batch of images.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
videos (`ImageInput`):
|
| 313 |
+
Video frames to preprocess. Expects a single or batch of video frames with pixel values ranging from 0
|
| 314 |
+
to 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`.
|
| 315 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
| 316 |
+
Whether to resize the image.
|
| 317 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
| 318 |
+
Size of the image after applying resize.
|
| 319 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
| 320 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
| 321 |
+
has an effect if `do_resize` is set to `True`.
|
| 322 |
+
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
|
| 323 |
+
Whether to center crop the image.
|
| 324 |
+
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
|
| 325 |
+
Size of the image after applying the center crop.
|
| 326 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
| 327 |
+
Whether to rescale the image values between `[-1, 1]` if `offset` is `True`, `[0, 1]` otherwise.
|
| 328 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
| 329 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
| 330 |
+
offset (`bool`, *optional*, defaults to `self.offset`):
|
| 331 |
+
Whether to scale the image in both negative and positive directions.
|
| 332 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
| 333 |
+
Whether to normalize the image.
|
| 334 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
| 335 |
+
Image mean.
|
| 336 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
| 337 |
+
Image standard deviation.
|
| 338 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
| 339 |
+
The type of tensors to return. Can be one of:
|
| 340 |
+
- Unset: Return a list of `np.ndarray`.
|
| 341 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
| 342 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
| 343 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
| 344 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
| 345 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
| 346 |
+
The channel dimension format for the output image. Can be one of:
|
| 347 |
+
- `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 348 |
+
- `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 349 |
+
- Unset: Use the inferred channel dimension format of the input image.
|
| 350 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
| 351 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
| 352 |
+
from the input image. Can be one of:
|
| 353 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
| 354 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
| 355 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
| 356 |
+
"""
|
| 357 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
| 358 |
+
resample = resample if resample is not None else self.resample
|
| 359 |
+
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
| 360 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
| 361 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
| 362 |
+
offset = offset if offset is not None else self.offset
|
| 363 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
| 364 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
| 365 |
+
image_std = image_std if image_std is not None else self.image_std
|
| 366 |
+
|
| 367 |
+
size = size if size is not None else self.size
|
| 368 |
+
size = get_size_dict(size, default_to_square=False)
|
| 369 |
+
crop_size = crop_size if crop_size is not None else self.crop_size
|
| 370 |
+
crop_size = get_size_dict(crop_size, param_name="crop_size")
|
| 371 |
+
|
| 372 |
+
if not valid_images(videos):
|
| 373 |
+
raise ValueError(
|
| 374 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
| 375 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
videos = make_batched(videos)
|
| 379 |
+
|
| 380 |
+
videos = [
|
| 381 |
+
[
|
| 382 |
+
self._preprocess_image(
|
| 383 |
+
image=img,
|
| 384 |
+
do_resize=do_resize,
|
| 385 |
+
size=size,
|
| 386 |
+
resample=resample,
|
| 387 |
+
do_center_crop=do_center_crop,
|
| 388 |
+
crop_size=crop_size,
|
| 389 |
+
do_rescale=do_rescale,
|
| 390 |
+
rescale_factor=rescale_factor,
|
| 391 |
+
offset=offset,
|
| 392 |
+
do_normalize=do_normalize,
|
| 393 |
+
image_mean=image_mean,
|
| 394 |
+
image_std=image_std,
|
| 395 |
+
data_format=data_format,
|
| 396 |
+
input_data_format=input_data_format,
|
| 397 |
+
)
|
| 398 |
+
for img in video
|
| 399 |
+
]
|
| 400 |
+
for video in videos
|
| 401 |
+
]
|
| 402 |
+
|
| 403 |
+
data = {"pixel_values": videos}
|
| 404 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
__all__ = ["VivitImageProcessor"]
|
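For reference, a minimal usage sketch of the image processor defined above (not part of the library source). It assumes the `google/vivit-b-16x2-kinetics400` checkpoint and `torch` are available, and uses random frames in place of a real video; the printed shape reflects the default 224x224 crop size.

```python
# Sketch: run 32 random RGB frames through VivitImageProcessor and inspect the output.
import numpy as np
from transformers import VivitImageProcessor

processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")

# A "video" is a list of frames; values are in [0, 255], so do_rescale stays True.
video = [np.random.randint(0, 256, (360, 480, 3), dtype=np.uint8) for _ in range(32)]

inputs = processor(video, return_tensors="pt")
print(inputs["pixel_values"].shape)  # expected roughly: torch.Size([1, 32, 3, 224, 224])
```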
docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py
ADDED
|
@@ -0,0 +1,844 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 Google AI and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch ViViT model."""
|
| 16 |
+
|
| 17 |
+
from typing import Callable, Optional, Set, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
from torch import nn
|
| 22 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 23 |
+
|
| 24 |
+
from ...activations import ACT2FN
|
| 25 |
+
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
| 26 |
+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 27 |
+
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
| 28 |
+
from ...utils import (
|
| 29 |
+
add_start_docstrings,
|
| 30 |
+
add_start_docstrings_to_model_forward,
|
| 31 |
+
logging,
|
| 32 |
+
replace_return_docstrings,
|
| 33 |
+
torch_int,
|
| 34 |
+
)
|
| 35 |
+
from .configuration_vivit import VivitConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
logger = logging.get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
_CHECKPOINT_FOR_DOC = "google/vivit-b-16x2-kinetics400"
|
| 41 |
+
_CONFIG_FOR_DOC = "VivitConfig"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class VivitTubeletEmbeddings(nn.Module):
|
| 45 |
+
"""
|
| 46 |
+
Construct Vivit Tubelet embeddings.
|
| 47 |
+
|
| 48 |
+
This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
|
| 49 |
+
shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
|
| 50 |
+
|
| 51 |
+
The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
|
| 52 |
+
(width // tubelet_size[2]).
|
| 53 |
+
"""
|
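As a concrete check of the seq_len formula above (a sketch using the values implied by the `google/vivit-b-16x2-kinetics400` checkpoint: 32 frames, 224x224 images, tubelet size [2, 16, 16]):

```python
# Illustrative numbers only; they reproduce the 3137-token sequence
# (3136 tubelet patches + 1 CLS token) seen in the VivitModel docstring example below.
num_frames, image_size, tubelet_size = 32, 224, (2, 16, 16)
seq_len = (
    (num_frames // tubelet_size[0])
    * (image_size // tubelet_size[1])
    * (image_size // tubelet_size[2])
)
print(seq_len)  # 3136
```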
| 54 |
+
|
| 55 |
+
def __init__(self, config):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.num_frames = config.num_frames
|
| 58 |
+
self.image_size = config.image_size
|
| 59 |
+
self.patch_size = config.tubelet_size
|
| 60 |
+
self.num_patches = (
|
| 61 |
+
(self.image_size // self.patch_size[2])
|
| 62 |
+
* (self.image_size // self.patch_size[1])
|
| 63 |
+
* (self.num_frames // self.patch_size[0])
|
| 64 |
+
)
|
| 65 |
+
self.embed_dim = config.hidden_size
|
| 66 |
+
|
| 67 |
+
self.projection = nn.Conv3d(
|
| 68 |
+
config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
|
| 72 |
+
batch_size, num_frames, num_channels, height, width = pixel_values.shape
|
| 73 |
+
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
|
| 74 |
+
raise ValueError(
|
| 75 |
+
f"Image image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# permute to (batch_size, num_channels, num_frames, height, width)
|
| 79 |
+
pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
|
| 80 |
+
|
| 81 |
+
x = self.projection(pixel_values)
|
| 82 |
+
# out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
|
| 83 |
+
# flattens time and space dimensions, transposes to (out_batch_size, flat_tokens, out_num_channels)
|
| 84 |
+
x = x.flatten(2).transpose(1, 2)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class VivitEmbeddings(nn.Module):
|
| 89 |
+
"""
|
| 90 |
+
Vivit Embeddings.
|
| 91 |
+
|
| 92 |
+
Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(self, config):
|
| 96 |
+
super().__init__()
|
| 97 |
+
|
| 98 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
| 99 |
+
self.patch_embeddings = VivitTubeletEmbeddings(config)
|
| 100 |
+
|
| 101 |
+
self.position_embeddings = nn.Parameter(
|
| 102 |
+
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
|
| 103 |
+
)
|
| 104 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 105 |
+
self.patch_size = config.tubelet_size[1:]
|
| 106 |
+
self.config = config
|
| 107 |
+
|
| 108 |
+
# Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
|
| 109 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
| 110 |
+
"""
|
| 111 |
+
This method allows interpolating the pre-trained position encodings so that the model can be used on higher resolution
|
| 112 |
+
images. This method is also adapted to support torch.jit tracing.
|
| 113 |
+
|
| 114 |
+
Adapted from:
|
| 115 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
| 116 |
+
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
num_patches = embeddings.shape[1] - 1
|
| 120 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
| 121 |
+
|
| 122 |
+
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
| 123 |
+
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
| 124 |
+
return self.position_embeddings
|
| 125 |
+
|
| 126 |
+
class_pos_embed = self.position_embeddings[:, :1]
|
| 127 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
| 128 |
+
|
| 129 |
+
dim = embeddings.shape[-1]
|
| 130 |
+
|
| 131 |
+
new_height = height // self.patch_size[0]
|
| 132 |
+
new_width = width // self.patch_size[1]
|
| 133 |
+
|
| 134 |
+
sqrt_num_positions = torch_int(num_positions**0.5)
|
| 135 |
+
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
| 136 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 137 |
+
|
| 138 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 139 |
+
patch_pos_embed,
|
| 140 |
+
size=(new_height, new_width),
|
| 141 |
+
mode="bicubic",
|
| 142 |
+
align_corners=False,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 146 |
+
|
| 147 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 148 |
+
|
| 149 |
+
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
|
| 150 |
+
batch_size, num_frames, num_channels, height, width = pixel_values.shape
|
| 151 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 152 |
+
|
| 153 |
+
cls_tokens = self.cls_token.tile([batch_size, 1, 1])
|
| 154 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
| 155 |
+
|
| 156 |
+
# add positional encoding to each token
|
| 157 |
+
if interpolate_pos_encoding:
|
| 158 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
| 159 |
+
else:
|
| 160 |
+
embeddings = embeddings + self.position_embeddings
|
| 161 |
+
|
| 162 |
+
embeddings = self.dropout(embeddings)
|
| 163 |
+
|
| 164 |
+
return embeddings
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
|
| 168 |
+
def eager_attention_forward(
|
| 169 |
+
module: nn.Module,
|
| 170 |
+
query: torch.Tensor,
|
| 171 |
+
key: torch.Tensor,
|
| 172 |
+
value: torch.Tensor,
|
| 173 |
+
attention_mask: Optional[torch.Tensor],
|
| 174 |
+
scaling: float,
|
| 175 |
+
dropout: float = 0.0,
|
| 176 |
+
**kwargs,
|
| 177 |
+
):
|
| 178 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 179 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
| 180 |
+
|
| 181 |
+
# Normalize the attention scores to probabilities.
|
| 182 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
| 183 |
+
|
| 184 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 185 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 186 |
+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
| 187 |
+
|
| 188 |
+
# Mask heads if we want to
|
| 189 |
+
if attention_mask is not None:
|
| 190 |
+
attn_weights = attn_weights * attention_mask
|
| 191 |
+
|
| 192 |
+
attn_output = torch.matmul(attn_weights, value)
|
| 193 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 194 |
+
|
| 195 |
+
return attn_output, attn_weights
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
|
| 199 |
+
class VivitSelfAttention(nn.Module):
|
| 200 |
+
def __init__(self, config: VivitConfig) -> None:
|
| 201 |
+
super().__init__()
|
| 202 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 203 |
+
raise ValueError(
|
| 204 |
+
f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
|
| 205 |
+
f"heads {config.num_attention_heads}."
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
self.config = config
|
| 209 |
+
self.num_attention_heads = config.num_attention_heads
|
| 210 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 211 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 212 |
+
self.dropout_prob = config.attention_probs_dropout_prob
|
| 213 |
+
self.scaling = self.attention_head_size**-0.5
|
| 214 |
+
self.is_causal = False
|
| 215 |
+
|
| 216 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 217 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 218 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
| 219 |
+
|
| 220 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
| 221 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 222 |
+
x = x.view(new_x_shape)
|
| 223 |
+
return x.permute(0, 2, 1, 3)
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
| 227 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 228 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
| 229 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
| 230 |
+
query_layer = self.transpose_for_scores(self.query(hidden_states))
|
| 231 |
+
|
| 232 |
+
attention_interface: Callable = eager_attention_forward
|
| 233 |
+
if self.config._attn_implementation != "eager":
|
| 234 |
+
if self.config._attn_implementation == "sdpa" and output_attentions:
|
| 235 |
+
logger.warning_once(
|
| 236 |
+
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
| 237 |
+
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
| 241 |
+
|
| 242 |
+
context_layer, attention_probs = attention_interface(
|
| 243 |
+
self,
|
| 244 |
+
query_layer,
|
| 245 |
+
key_layer,
|
| 246 |
+
value_layer,
|
| 247 |
+
head_mask,
|
| 248 |
+
is_causal=self.is_causal,
|
| 249 |
+
scaling=self.scaling,
|
| 250 |
+
dropout=0.0 if not self.training else self.dropout_prob,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 254 |
+
context_layer = context_layer.reshape(new_context_layer_shape)
|
| 255 |
+
|
| 256 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 257 |
+
|
| 258 |
+
return outputs
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
|
| 262 |
+
class VivitSelfOutput(nn.Module):
|
| 263 |
+
"""
|
| 264 |
+
The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
|
| 265 |
+
layernorm applied before each block.
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
def __init__(self, config: VivitConfig) -> None:
|
| 269 |
+
super().__init__()
|
| 270 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 271 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 272 |
+
|
| 273 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 274 |
+
hidden_states = self.dense(hidden_states)
|
| 275 |
+
hidden_states = self.dropout(hidden_states)
|
| 276 |
+
|
| 277 |
+
return hidden_states
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Vivit
|
| 281 |
+
class VivitAttention(nn.Module):
|
| 282 |
+
def __init__(self, config: VivitConfig) -> None:
|
| 283 |
+
super().__init__()
|
| 284 |
+
self.attention = VivitSelfAttention(config)
|
| 285 |
+
self.output = VivitSelfOutput(config)
|
| 286 |
+
self.pruned_heads = set()
|
| 287 |
+
|
| 288 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
| 289 |
+
if len(heads) == 0:
|
| 290 |
+
return
|
| 291 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 292 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Prune linear layers
|
| 296 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
| 297 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
| 298 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
| 299 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 300 |
+
|
| 301 |
+
# Update hyper params and store pruned heads
|
| 302 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
| 303 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
| 304 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 305 |
+
|
| 306 |
+
def forward(
|
| 307 |
+
self,
|
| 308 |
+
hidden_states: torch.Tensor,
|
| 309 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 310 |
+
output_attentions: bool = False,
|
| 311 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
| 312 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
| 313 |
+
|
| 314 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 315 |
+
|
| 316 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 317 |
+
return outputs
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class VivitIntermediate(nn.Module):
|
| 321 |
+
def __init__(self, config):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 324 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 325 |
+
if isinstance(config.hidden_act, str):
|
| 326 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 327 |
+
else:
|
| 328 |
+
self.intermediate_act_fn = config.hidden_act
|
| 329 |
+
|
| 330 |
+
def forward(self, hidden_states):
|
| 331 |
+
hidden_states = self.dense(hidden_states)
|
| 332 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 333 |
+
hidden_states = self.dropout(hidden_states)
|
| 334 |
+
|
| 335 |
+
return hidden_states
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
class VivitOutput(nn.Module):
|
| 339 |
+
def __init__(self, config):
|
| 340 |
+
super().__init__()
|
| 341 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 342 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 343 |
+
|
| 344 |
+
def forward(self, hidden_states, input_tensor):
|
| 345 |
+
hidden_states = self.dense(hidden_states)
|
| 346 |
+
|
| 347 |
+
hidden_states = self.dropout(hidden_states)
|
| 348 |
+
|
| 349 |
+
hidden_states = hidden_states + input_tensor
|
| 350 |
+
|
| 351 |
+
return hidden_states
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class VivitLayer(nn.Module):
|
| 355 |
+
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
|
| 356 |
+
|
| 357 |
+
def __init__(self, config):
|
| 358 |
+
super().__init__()
|
| 359 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
| 360 |
+
self.seq_len_dim = 1
|
| 361 |
+
self.attention = VivitAttention(config)
|
| 362 |
+
self.intermediate = VivitIntermediate(config)
|
| 363 |
+
self.output = VivitOutput(config)
|
| 364 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 365 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 366 |
+
|
| 367 |
+
def forward(self, hidden_states, head_mask=None, output_attentions=False):
|
| 368 |
+
self_attention_outputs = self.attention(
|
| 369 |
+
# in Vivit, layernorm is applied before self-attention
|
| 370 |
+
self.layernorm_before(hidden_states),
|
| 371 |
+
head_mask,
|
| 372 |
+
output_attentions=output_attentions,
|
| 373 |
+
)
|
| 374 |
+
attention_output = self_attention_outputs[0]
|
| 375 |
+
# add self attentions if we output attention weights
|
| 376 |
+
outputs = self_attention_outputs[1:]
|
| 377 |
+
|
| 378 |
+
# first residual connection
|
| 379 |
+
hidden_states = attention_output + hidden_states
|
| 380 |
+
|
| 381 |
+
# in Vivit, layernorm is also applied after self-attention
|
| 382 |
+
layer_output = self.layernorm_after(hidden_states)
|
| 383 |
+
layer_output = self.intermediate(layer_output)
|
| 384 |
+
|
| 385 |
+
# second residual connection is done here
|
| 386 |
+
layer_output = self.output(layer_output, hidden_states)
|
| 387 |
+
|
| 388 |
+
outputs = (layer_output,) + outputs
|
| 389 |
+
|
| 390 |
+
return outputs
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class VivitEncoder(nn.Module):
|
| 394 |
+
def __init__(self, config):
|
| 395 |
+
super().__init__()
|
| 396 |
+
self.config = config
|
| 397 |
+
self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
|
| 398 |
+
self.gradient_checkpointing = False
|
| 399 |
+
|
| 400 |
+
def forward(
|
| 401 |
+
self,
|
| 402 |
+
hidden_states,
|
| 403 |
+
head_mask=None,
|
| 404 |
+
output_attentions=False,
|
| 405 |
+
output_hidden_states=False,
|
| 406 |
+
return_dict=True,
|
| 407 |
+
):
|
| 408 |
+
all_hidden_states = () if output_hidden_states else None
|
| 409 |
+
all_self_attentions = () if output_attentions else None
|
| 410 |
+
|
| 411 |
+
for i, layer_module in enumerate(self.layer):
|
| 412 |
+
if output_hidden_states:
|
| 413 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 414 |
+
|
| 415 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
| 416 |
+
|
| 417 |
+
if self.gradient_checkpointing and self.training:
|
| 418 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 419 |
+
layer_module.__call__,
|
| 420 |
+
hidden_states,
|
| 421 |
+
layer_head_mask,
|
| 422 |
+
output_attentions,
|
| 423 |
+
)
|
| 424 |
+
else:
|
| 425 |
+
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
|
| 426 |
+
|
| 427 |
+
hidden_states = layer_outputs[0]
|
| 428 |
+
|
| 429 |
+
if output_attentions:
|
| 430 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
| 431 |
+
|
| 432 |
+
if output_hidden_states:
|
| 433 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 434 |
+
|
| 435 |
+
if not return_dict:
|
| 436 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
| 437 |
+
return BaseModelOutput(
|
| 438 |
+
last_hidden_state=hidden_states,
|
| 439 |
+
hidden_states=all_hidden_states,
|
| 440 |
+
attentions=all_self_attentions,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class VivitPooler(nn.Module):
|
| 445 |
+
def __init__(self, config):
|
| 446 |
+
super().__init__()
|
| 447 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 448 |
+
self.activation = nn.Tanh()
|
| 449 |
+
|
| 450 |
+
def forward(self, hidden_states):
|
| 451 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 452 |
+
# to the first token.
|
| 453 |
+
first_token_tensor = hidden_states[:, 0]
|
| 454 |
+
pooled_output = self.dense(first_token_tensor)
|
| 455 |
+
pooled_output = self.activation(pooled_output)
|
| 456 |
+
return pooled_output
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class VivitPreTrainedModel(PreTrainedModel):
|
| 460 |
+
"""
|
| 461 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 462 |
+
models.
|
| 463 |
+
"""
|
| 464 |
+
|
| 465 |
+
config_class = VivitConfig
|
| 466 |
+
base_model_prefix = "vivit"
|
| 467 |
+
main_input_name = "pixel_values"
|
| 468 |
+
supports_gradient_checkpointing = True
|
| 469 |
+
_no_split_modules = []
|
| 470 |
+
_supports_sdpa = True
|
| 471 |
+
_supports_flash_attn_2 = True
|
| 472 |
+
|
| 473 |
+
def _init_weights(self, module):
|
| 474 |
+
"""Initialize the weights"""
|
| 475 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
| 476 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 477 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 478 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 479 |
+
if module.bias is not None:
|
| 480 |
+
module.bias.data.zero_()
|
| 481 |
+
elif isinstance(module, nn.Embedding):
|
| 482 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 483 |
+
if module.padding_idx is not None:
|
| 484 |
+
module.weight.data[module.padding_idx].zero_()
|
| 485 |
+
elif isinstance(module, nn.LayerNorm):
|
| 486 |
+
module.bias.data.zero_()
|
| 487 |
+
module.weight.data.fill_(1.0)
|
| 488 |
+
elif isinstance(module, VivitEmbeddings):
|
| 489 |
+
module.cls_token.data.zero_()
|
| 490 |
+
module.position_embeddings.data.zero_()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
VIVIT_START_DOCSTRING = r"""
|
| 494 |
+
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
| 495 |
+
as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
|
| 496 |
+
behavior.
|
| 497 |
+
|
| 498 |
+
Parameters:
|
| 499 |
+
config ([`VivitConfig`]): Model configuration class with all the parameters of the model.
|
| 500 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 501 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 502 |
+
"""
|
| 503 |
+
|
| 504 |
+
VIVIT_INPUTS_DOCSTRING = r"""
|
| 505 |
+
Args:
|
| 506 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
|
| 507 |
+
Pixel values. Pixel values can be obtained using [`VivitImageProcessor`]. See
|
| 508 |
+
[`VivitImageProcessor.preprocess`] for details.
|
| 509 |
+
|
| 510 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
| 511 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
| 512 |
+
|
| 513 |
+
- 1 indicates the head is **not masked**,
|
| 514 |
+
- 0 indicates the head is **masked**.
|
| 515 |
+
|
| 516 |
+
output_attentions (`bool`, *optional*):
|
| 517 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 518 |
+
tensors for more detail.
|
| 519 |
+
output_hidden_states (`bool`, *optional*):
|
| 520 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 521 |
+
more detail.
|
| 522 |
+
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
|
| 523 |
+
Whether to interpolate the pre-trained position encodings.
|
| 524 |
+
return_dict (`bool`, *optional*):
|
| 525 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 526 |
+
"""
|
| 527 |
+
|
| 528 |
+
|
| 529 |
+
@add_start_docstrings(
|
| 530 |
+
"The bare ViViT Transformer model outputting raw hidden-states without any specific head on top.",
|
| 531 |
+
VIVIT_START_DOCSTRING,
|
| 532 |
+
)
|
| 533 |
+
class VivitModel(VivitPreTrainedModel):
|
| 534 |
+
def __init__(self, config, add_pooling_layer=True):
|
| 535 |
+
super().__init__(config)
|
| 536 |
+
self.config = config
|
| 537 |
+
|
| 538 |
+
self.embeddings = VivitEmbeddings(config)
|
| 539 |
+
self.encoder = VivitEncoder(config)
|
| 540 |
+
|
| 541 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 542 |
+
self.pooler = VivitPooler(config) if add_pooling_layer else None
|
| 543 |
+
|
| 544 |
+
# Initialize weights and apply final processing
|
| 545 |
+
self.post_init()
|
| 546 |
+
|
| 547 |
+
def get_input_embeddings(self):
|
| 548 |
+
return self.embeddings.patch_embeddings
|
| 549 |
+
|
| 550 |
+
def _prune_heads(self, heads_to_prune):
|
| 551 |
+
"""
|
| 552 |
+
Prunes heads of the model.
|
| 553 |
+
|
| 554 |
+
Args:
|
| 555 |
+
heads_to_prune:
|
| 556 |
+
dict of {layer_num: list of heads to prune in this layer}
|
| 557 |
+
"""
|
| 558 |
+
for layer, heads in heads_to_prune.items():
|
| 559 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 560 |
+
|
| 561 |
+
@add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
|
| 562 |
+
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
| 563 |
+
def forward(
|
| 564 |
+
self,
|
| 565 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 566 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 567 |
+
output_attentions: Optional[bool] = None,
|
| 568 |
+
output_hidden_states: Optional[bool] = None,
|
| 569 |
+
interpolate_pos_encoding: bool = False,
|
| 570 |
+
return_dict: Optional[bool] = None,
|
| 571 |
+
) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
|
| 572 |
+
r"""
|
| 573 |
+
Returns:
|
| 574 |
+
|
| 575 |
+
Examples:
|
| 576 |
+
|
| 577 |
+
```python
|
| 578 |
+
>>> import av
|
| 579 |
+
>>> import numpy as np
|
| 580 |
+
|
| 581 |
+
>>> from transformers import VivitImageProcessor, VivitModel
|
| 582 |
+
>>> from huggingface_hub import hf_hub_download
|
| 583 |
+
|
| 584 |
+
>>> np.random.seed(0)
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
>>> def read_video_pyav(container, indices):
|
| 588 |
+
... '''
|
| 589 |
+
... Decode the video with PyAV decoder.
|
| 590 |
+
... Args:
|
| 591 |
+
... container (`av.container.input.InputContainer`): PyAV container.
|
| 592 |
+
... indices (`List[int]`): List of frame indices to decode.
|
| 593 |
+
... Returns:
|
| 594 |
+
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
| 595 |
+
... '''
|
| 596 |
+
... frames = []
|
| 597 |
+
... container.seek(0)
|
| 598 |
+
... start_index = indices[0]
|
| 599 |
+
... end_index = indices[-1]
|
| 600 |
+
... for i, frame in enumerate(container.decode(video=0)):
|
| 601 |
+
... if i > end_index:
|
| 602 |
+
... break
|
| 603 |
+
... if i >= start_index and i in indices:
|
| 604 |
+
... frames.append(frame)
|
| 605 |
+
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
| 609 |
+
... '''
|
| 610 |
+
... Sample a given number of frame indices from the video.
|
| 611 |
+
... Args:
|
| 612 |
+
... clip_len (`int`): Total number of frames to sample.
|
| 613 |
+
... frame_sample_rate (`int`): Sample every n-th frame.
|
| 614 |
+
... seg_len (`int`): Maximum allowed index of sample's last frame.
|
| 615 |
+
... Returns:
|
| 616 |
+
... indices (`List[int]`): List of sampled frame indices
|
| 617 |
+
... '''
|
| 618 |
+
... converted_len = int(clip_len * frame_sample_rate)
|
| 619 |
+
... end_idx = np.random.randint(converted_len, seg_len)
|
| 620 |
+
... start_idx = end_idx - converted_len
|
| 621 |
+
... indices = np.linspace(start_idx, end_idx, num=clip_len)
|
| 622 |
+
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
| 623 |
+
... return indices
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
>>> # video clip consists of 300 frames (10 seconds at 30 FPS)
|
| 627 |
+
>>> file_path = hf_hub_download(
|
| 628 |
+
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
|
| 629 |
+
... )
|
| 630 |
+
>>> container = av.open(file_path)
|
| 631 |
+
|
| 632 |
+
>>> # sample 32 frames
|
| 633 |
+
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
|
| 634 |
+
>>> video = read_video_pyav(container=container, indices=indices)
|
| 635 |
+
|
| 636 |
+
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 637 |
+
>>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 638 |
+
|
| 639 |
+
>>> # prepare video for the model
|
| 640 |
+
>>> inputs = image_processor(list(video), return_tensors="pt")
|
| 641 |
+
|
| 642 |
+
>>> # forward pass
|
| 643 |
+
>>> outputs = model(**inputs)
|
| 644 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
| 645 |
+
>>> list(last_hidden_states.shape)
|
| 646 |
+
[1, 3137, 768]
|
| 647 |
+
```"""
|
| 648 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 649 |
+
output_hidden_states = (
|
| 650 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 651 |
+
)
|
| 652 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 653 |
+
|
| 654 |
+
if pixel_values is None:
|
| 655 |
+
raise ValueError("You have to specify pixel_values")
|
| 656 |
+
|
| 657 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 658 |
+
|
| 659 |
+
embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 660 |
+
|
| 661 |
+
encoder_outputs = self.encoder(
|
| 662 |
+
embedding_output,
|
| 663 |
+
head_mask=head_mask,
|
| 664 |
+
output_attentions=output_attentions,
|
| 665 |
+
output_hidden_states=output_hidden_states,
|
| 666 |
+
return_dict=return_dict,
|
| 667 |
+
)
|
| 668 |
+
sequence_output = encoder_outputs[0]
|
| 669 |
+
sequence_output = self.layernorm(sequence_output)
|
| 670 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 671 |
+
|
| 672 |
+
if not return_dict:
|
| 673 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
| 674 |
+
|
| 675 |
+
return BaseModelOutputWithPooling(
|
| 676 |
+
last_hidden_state=sequence_output,
|
| 677 |
+
pooler_output=pooled_output,
|
| 678 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 679 |
+
attentions=encoder_outputs.attentions,
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
@add_start_docstrings(
|
| 684 |
+
"""
|
| 685 |
+
ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
|
| 686 |
+
[CLS] token) e.g. for Kinetics-400.
|
| 687 |
+
|
| 688 |
+
<Tip>
|
| 689 |
+
|
| 690 |
+
Note that it's possible to fine-tune ViViT on higher resolution images than the ones it has been trained on, by
|
| 691 |
+
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
|
| 692 |
+
position embeddings to the higher resolution.
|
| 693 |
+
|
| 694 |
+
</Tip>
|
| 695 |
+
""",
|
| 696 |
+
VIVIT_START_DOCSTRING,
|
| 697 |
+
)
|
| 698 |
+
class VivitForVideoClassification(VivitPreTrainedModel):
|
| 699 |
+
def __init__(self, config):
|
| 700 |
+
super().__init__(config)
|
| 701 |
+
|
| 702 |
+
self.num_labels = config.num_labels
|
| 703 |
+
self.vivit = VivitModel(config, add_pooling_layer=False)
|
| 704 |
+
|
| 705 |
+
# Classifier head
|
| 706 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
| 707 |
+
|
| 708 |
+
# Initialize weights and apply final processing
|
| 709 |
+
self.post_init()
|
| 710 |
+
|
| 711 |
+
@add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
|
| 712 |
+
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
| 713 |
+
def forward(
|
| 714 |
+
self,
|
| 715 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 716 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 717 |
+
labels: Optional[torch.LongTensor] = None,
|
| 718 |
+
output_attentions: Optional[bool] = None,
|
| 719 |
+
output_hidden_states: Optional[bool] = None,
|
| 720 |
+
interpolate_pos_encoding: bool = False,
|
| 721 |
+
return_dict: Optional[bool] = None,
|
| 722 |
+
) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
|
| 723 |
+
r"""
|
| 724 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
| 725 |
+
Labels for computing the video classification/regression loss. Indices should be in `[0, ...,
|
| 726 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
| 727 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 728 |
+
|
| 729 |
+
Returns:
|
| 730 |
+
|
| 731 |
+
Examples:
|
| 732 |
+
|
| 733 |
+
```python
|
| 734 |
+
>>> import av
|
| 735 |
+
>>> import numpy as np
|
| 736 |
+
>>> import torch
|
| 737 |
+
|
| 738 |
+
>>> from transformers import VivitImageProcessor, VivitForVideoClassification
|
| 739 |
+
>>> from huggingface_hub import hf_hub_download
|
| 740 |
+
|
| 741 |
+
>>> np.random.seed(0)
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
>>> def read_video_pyav(container, indices):
|
| 745 |
+
... '''
|
| 746 |
+
... Decode the video with PyAV decoder.
|
| 747 |
+
... Args:
|
| 748 |
+
... container (`av.container.input.InputContainer`): PyAV container.
|
| 749 |
+
... indices (`List[int]`): List of frame indices to decode.
|
| 750 |
+
... Returns:
|
| 751 |
+
... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
|
| 752 |
+
... '''
|
| 753 |
+
... frames = []
|
| 754 |
+
... container.seek(0)
|
| 755 |
+
... start_index = indices[0]
|
| 756 |
+
... end_index = indices[-1]
|
| 757 |
+
... for i, frame in enumerate(container.decode(video=0)):
|
| 758 |
+
... if i > end_index:
|
| 759 |
+
... break
|
| 760 |
+
... if i >= start_index and i in indices:
|
| 761 |
+
... frames.append(frame)
|
| 762 |
+
... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
>>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
|
| 766 |
+
... '''
|
| 767 |
+
... Sample a given number of frame indices from the video.
|
| 768 |
+
... Args:
|
| 769 |
+
... clip_len (`int`): Total number of frames to sample.
|
| 770 |
+
... frame_sample_rate (`int`): Sample every n-th frame.
|
| 771 |
+
... seg_len (`int`): Maximum allowed index of sample's last frame.
|
| 772 |
+
... Returns:
|
| 773 |
+
... indices (`List[int]`): List of sampled frame indices
|
| 774 |
+
... '''
|
| 775 |
+
... converted_len = int(clip_len * frame_sample_rate)
|
| 776 |
+
... end_idx = np.random.randint(converted_len, seg_len)
|
| 777 |
+
... start_idx = end_idx - converted_len
|
| 778 |
+
... indices = np.linspace(start_idx, end_idx, num=clip_len)
|
| 779 |
+
... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
|
| 780 |
+
... return indices
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
>>> # video clip consists of 300 frames (10 seconds at 30 FPS)
|
| 784 |
+
>>> file_path = hf_hub_download(
|
| 785 |
+
... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
|
| 786 |
+
... )
|
| 787 |
+
>>> container = av.open(file_path)
|
| 788 |
+
|
| 789 |
+
>>> # sample 32 frames
|
| 790 |
+
>>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
|
| 791 |
+
>>> video = read_video_pyav(container=container, indices=indices)
|
| 792 |
+
|
| 793 |
+
>>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 794 |
+
>>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 795 |
+
|
| 796 |
+
>>> inputs = image_processor(list(video), return_tensors="pt")
|
| 797 |
+
|
| 798 |
+
>>> with torch.no_grad():
|
| 799 |
+
... outputs = model(**inputs)
|
| 800 |
+
... logits = outputs.logits
|
| 801 |
+
|
| 802 |
+
>>> # model predicts one of the 400 Kinetics-400 classes
|
| 803 |
+
>>> predicted_label = logits.argmax(-1).item()
|
| 804 |
+
>>> print(model.config.id2label[predicted_label])
|
| 805 |
+
LABEL_116
|
| 806 |
+
```"""
|
| 807 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 808 |
+
|
| 809 |
+
outputs = self.vivit(
|
| 810 |
+
pixel_values,
|
| 811 |
+
head_mask=head_mask,
|
| 812 |
+
output_attentions=output_attentions,
|
| 813 |
+
output_hidden_states=output_hidden_states,
|
| 814 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 815 |
+
return_dict=return_dict,
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
sequence_output = outputs[0]
|
| 819 |
+
|
| 820 |
+
logits = self.classifier(sequence_output[:, 0, :])
|
| 821 |
+
|
| 822 |
+
loss = None
|
| 823 |
+
if labels is not None:
|
| 824 |
+
if self.num_labels == 1:
|
| 825 |
+
# We are doing regression
|
| 826 |
+
loss_fct = MSELoss()
|
| 827 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
| 828 |
+
else:
|
| 829 |
+
loss_fct = CrossEntropyLoss()
|
| 830 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 831 |
+
|
| 832 |
+
if not return_dict:
|
| 833 |
+
output = (logits,) + outputs[2:]
|
| 834 |
+
return ((loss,) + output) if loss is not None else output
|
| 835 |
+
|
| 836 |
+
return ImageClassifierOutput(
|
| 837 |
+
loss=loss,
|
| 838 |
+
logits=logits,
|
| 839 |
+
hidden_states=outputs.hidden_states,
|
| 840 |
+
attentions=outputs.attentions,
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
__all__ = ["VivitModel", "VivitPreTrainedModel", "VivitForVideoClassification"]
|
docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from ...utils import _LazyModule
|
| 17 |
+
from ...utils.import_utils import define_import_structure
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from .configuration_wav2vec2 import *
|
| 22 |
+
from .feature_extraction_wav2vec2 import *
|
| 23 |
+
from .modeling_flax_wav2vec2 import *
|
| 24 |
+
from .modeling_tf_wav2vec2 import *
|
| 25 |
+
from .modeling_wav2vec2 import *
|
| 26 |
+
from .processing_wav2vec2 import *
|
| 27 |
+
from .tokenization_wav2vec2 import *
|
| 28 |
+
else:
|
| 29 |
+
import sys
|
| 30 |
+
|
| 31 |
+
_file = globals()["__file__"]
|
| 32 |
+
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
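The `_LazyModule` registration above defers the heavy submodule imports until an attribute is actually accessed. A small sketch of the effect (assuming `transformers` is installed; `Wav2Vec2Config` is defined in the configuration file added below):

```python
# Sketch: importing the package is cheap; the concrete module is resolved lazily
# on first attribute access through the _LazyModule proxy.
from transformers.models import wav2vec2

config_cls = wav2vec2.Wav2Vec2Config  # triggers the actual import of configuration_wav2vec2
print(config_cls.model_type)          # "wav2vec2"
```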
docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py
ADDED
|
@@ -0,0 +1,347 @@
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Wav2Vec2 model configuration"""
|
| 16 |
+
|
| 17 |
+
import functools
|
| 18 |
+
import operator
|
| 19 |
+
|
| 20 |
+
from ...configuration_utils import PretrainedConfig
|
| 21 |
+
from ...utils import logging
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Wav2Vec2Config(PretrainedConfig):
|
| 28 |
+
r"""
|
| 29 |
+
This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate an
|
| 30 |
+
Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 31 |
+
with the defaults will yield a similar configuration to that of the Wav2Vec2
|
| 32 |
+
[facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
|
| 33 |
+
|
| 34 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 35 |
+
documentation from [`PretrainedConfig`] for more information.
|
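A short usage sketch of this configuration pattern (an assumption-light example following the standard `transformers` API; the 768 value comes from the `hidden_size` default documented below):

```python
# Sketch: a default Wav2Vec2Config mirrors the facebook/wav2vec2-base-960h architecture;
# passing it to Wav2Vec2Model builds a randomly initialized model with that architecture.
from transformers import Wav2Vec2Config, Wav2Vec2Model

configuration = Wav2Vec2Config()       # defaults ~ wav2vec2-base-960h
model = Wav2Vec2Model(configuration)   # weights are randomly initialized
print(model.config.hidden_size)        # 768
```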
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
vocab_size (`int`, *optional*, defaults to 32):
|
| 40 |
+
Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
|
| 41 |
+
the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`]. Vocabulary size of the
|
| 42 |
+
model. Defines the different tokens that can be represented by the *inputs_ids* passed to the forward
|
| 43 |
+
method of [`Wav2Vec2Model`].
|
| 44 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
| 45 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 46 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
| 47 |
+
Number of hidden layers in the Transformer encoder.
|
| 48 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
| 49 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 50 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
| 51 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 52 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
| 53 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
| 54 |
+
`"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 55 |
+
hidden_dropout (`float`, *optional*, defaults to 0.1):
|
| 56 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
| 57 |
+
activation_dropout (`float`, *optional*, defaults to 0.1):
|
| 58 |
+
The dropout ratio for activations inside the fully connected layer.
|
| 59 |
+
attention_dropout (`float`, *optional*, defaults to 0.1):
|
| 60 |
+
The dropout ratio for the attention probabilities.
|
| 61 |
+
final_dropout (`float`, *optional*, defaults to 0.1):
|
| 62 |
+
The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
|
| 63 |
+
layerdrop (`float`, *optional*, defaults to 0.1):
|
| 64 |
+
The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
|
| 65 |
+
details.
|
| 66 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 67 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 68 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 69 |
+
The epsilon used by the layer normalization layers.
|
| 70 |
+
feat_extract_norm (`str`, *optional*, defaults to `"group"`):
|
| 71 |
+
The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
|
| 72 |
+
normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
|
| 73 |
+
convolutional layers.
|
| 74 |
+
feat_proj_dropout (`float`, *optional*, defaults to 0.0):
|
| 75 |
+
The dropout probability for output of the feature encoder.
|
| 76 |
+
feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
|
| 77 |
+
The non-linear activation function (function or string) in the 1D convolutional layers of the feature
|
| 78 |
+
extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
|
| 81 |
+
conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
|
| 82 |
+
A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
|
| 83 |
+
feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
|
| 84 |
+
conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
|
| 85 |
+
A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
|
| 86 |
+
of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
|
| 87 |
+
conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
|
| 88 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
|
| 89 |
+
length of *conv_kernel* defines the number of convolutional layers and has to match the length of
|
| 90 |
+
*conv_dim*.
|
| 91 |
+
conv_bias (`bool`, *optional*, defaults to `False`):
|
| 92 |
+
Whether the 1D convolutional layers have a bias.
|
| 93 |
+
num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
|
| 94 |
+
Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
|
| 95 |
+
embeddings layer.
|
| 96 |
+
num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
|
| 97 |
+
Number of groups of 1D convolutional positional embeddings layer.
|
| 98 |
+
do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
|
| 99 |
+
Whether to apply the *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
|
| 100 |
+
True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
|
| 101 |
+
False` corresponds to applying layer norm after the attention layer.
|
| 102 |
+
apply_spec_augment (`bool`, *optional*, defaults to `True`):
|
| 103 |
+
Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
|
| 104 |
+
[SpecAugment: A Simple Data Augmentation Method for Automatic Speech
|
| 105 |
+
Recognition](https://arxiv.org/abs/1904.08779).
|
| 106 |
+
mask_time_prob (`float`, *optional*, defaults to 0.05):
|
| 107 |
+
Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
|
| 108 |
+
procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
|
| 109 |
+
reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
|
| 110 |
+
masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
|
| 111 |
+
actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
|
| 112 |
+
mask_time_length (`int`, *optional*, defaults to 10):
|
| 113 |
+
Length of vector span along the time axis.
|
| 114 |
+
mask_time_min_masks (`int`, *optional*, defaults to 2):
|
| 115 |
+
The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
|
| 116 |
+
irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
|
| 117 |
+
mask_time_min_masks`
|
| 118 |
+
mask_feature_prob (`float`, *optional*, defaults to 0.0):
|
| 119 |
+
Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
|
| 120 |
+
masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks over
|
| 121 |
+
the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
|
| 122 |
+
span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
|
| 123 |
+
may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
|
| 124 |
+
True`.
|
| 125 |
+
mask_feature_length (`int`, *optional*, defaults to 10):
|
| 126 |
+
Length of vector span along the feature axis.
|
| 127 |
+
mask_feature_min_masks (`int`, *optional*, defaults to 0):
|
| 128 |
+
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
| 129 |
+
step, irrespective of `mask_feature_prob`. Only relevant if
|
| 130 |
+
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`
|
| 131 |
+
num_codevectors_per_group (`int`, *optional*, defaults to 320):
|
| 132 |
+
Number of entries in each quantization codebook (group).
|
| 133 |
+
num_codevector_groups (`int`, *optional*, defaults to 2):
|
| 134 |
+
Number of codevector groups for product codevector quantization.
|
| 135 |
+
contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
|
| 136 |
+
The temperature *kappa* in the contrastive loss.
|
| 137 |
+
feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
|
| 138 |
+
The dropout probability for the output of the feature encoder that's used by the quantizer.
|
| 139 |
+
num_negatives (`int`, *optional*, defaults to 100):
|
| 140 |
+
Number of negative samples for the contrastive loss.
|
| 141 |
+
codevector_dim (`int`, *optional*, defaults to 256):
|
| 142 |
+
Dimensionality of the quantized feature vectors.
|
| 143 |
+
proj_codevector_dim (`int`, *optional*, defaults to 256):
|
| 144 |
+
Dimensionality of the final projection of both the quantized and the transformer features.
|
| 145 |
+
diversity_loss_weight (`float`, *optional*, defaults to 0.1):
|
| 146 |
+
The weight of the codebook diversity loss component.
|
| 147 |
+
ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
|
| 148 |
+
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
|
| 149 |
+
instance of [`Wav2Vec2ForCTC`].
|
| 150 |
+
ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
|
| 151 |
+
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
|
| 152 |
+
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
|
| 153 |
+
of [`Wav2Vec2ForCTC`].
|
| 154 |
+
use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
|
| 155 |
+
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
| 156 |
+
instance of [`Wav2Vec2ForSequenceClassification`].
|
| 157 |
+
classifier_proj_size (`int`, *optional*, defaults to 256):
|
| 158 |
+
Dimensionality of the projection before token mean-pooling for classification.
|
| 159 |
+
tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
|
| 160 |
+
A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
|
| 161 |
+
module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
|
| 162 |
+
tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
|
| 163 |
+
A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
|
| 164 |
+
*XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
|
| 165 |
+
tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
|
| 166 |
+
A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
|
| 167 |
+
*XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
|
| 168 |
+
xvector_output_dim (`int`, *optional*, defaults to 512):
|
| 169 |
+
Dimensionality of the *XVector* embedding vectors.
|
| 170 |
+
add_adapter (`bool`, *optional*, defaults to `False`):
|
| 171 |
+
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
|
| 172 |
+
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
|
| 173 |
+
adapter_kernel_size (`int`, *optional*, defaults to 3):
|
| 174 |
+
Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 175 |
+
adapter_stride (`int`, *optional*, defaults to 2):
|
| 176 |
+
Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
|
| 177 |
+
num_adapter_layers (`int`, *optional*, defaults to 3):
|
| 178 |
+
Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
|
| 179 |
+
True`.
|
| 180 |
+
adapter_attn_dim (`int`, *optional*):
|
| 181 |
+
Dimension of the attention adapter weights to be used in each attention block. An example of a model using
|
| 182 |
+
attention adapters is [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all).
|
| 183 |
+
output_hidden_size (`int`, *optional*):
|
| 184 |
+
Dimensionality of the encoder output layer. If not defined, this defaults to *hidden_size*. Only relevant
|
| 185 |
+
if `add_adapter is True`.
|
| 186 |
+
|
| 187 |
+
Example:
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
>>> from transformers import Wav2Vec2Config, Wav2Vec2Model
|
| 191 |
+
|
| 192 |
+
>>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
|
| 193 |
+
>>> configuration = Wav2Vec2Config()
|
| 194 |
+
|
| 195 |
+
>>> # Initializing a model (with random weights) from the facebook/wav2vec2-base-960h style configuration
|
| 196 |
+
>>> model = Wav2Vec2Model(configuration)
|
| 197 |
+
|
| 198 |
+
>>> # Accessing the model configuration
|
| 199 |
+
>>> configuration = model.config
|
| 200 |
+
```"""
|
| 201 |
+
|
| 202 |
+
model_type = "wav2vec2"
|
| 203 |
+
|
| 204 |
+
def __init__(
|
| 205 |
+
self,
|
| 206 |
+
vocab_size=32,
|
| 207 |
+
hidden_size=768,
|
| 208 |
+
num_hidden_layers=12,
|
| 209 |
+
num_attention_heads=12,
|
| 210 |
+
intermediate_size=3072,
|
| 211 |
+
hidden_act="gelu",
|
| 212 |
+
hidden_dropout=0.1,
|
| 213 |
+
activation_dropout=0.1,
|
| 214 |
+
attention_dropout=0.1,
|
| 215 |
+
feat_proj_dropout=0.0,
|
| 216 |
+
feat_quantizer_dropout=0.0,
|
| 217 |
+
final_dropout=0.1,
|
| 218 |
+
layerdrop=0.1,
|
| 219 |
+
initializer_range=0.02,
|
| 220 |
+
layer_norm_eps=1e-5,
|
| 221 |
+
feat_extract_norm="group",
|
| 222 |
+
feat_extract_activation="gelu",
|
| 223 |
+
conv_dim=(512, 512, 512, 512, 512, 512, 512),
|
| 224 |
+
conv_stride=(5, 2, 2, 2, 2, 2, 2),
|
| 225 |
+
conv_kernel=(10, 3, 3, 3, 3, 2, 2),
|
| 226 |
+
conv_bias=False,
|
| 227 |
+
num_conv_pos_embeddings=128,
|
| 228 |
+
num_conv_pos_embedding_groups=16,
|
| 229 |
+
do_stable_layer_norm=False,
|
| 230 |
+
apply_spec_augment=True,
|
| 231 |
+
mask_time_prob=0.05,
|
| 232 |
+
mask_time_length=10,
|
| 233 |
+
mask_time_min_masks=2,
|
| 234 |
+
mask_feature_prob=0.0,
|
| 235 |
+
mask_feature_length=10,
|
| 236 |
+
mask_feature_min_masks=0,
|
| 237 |
+
num_codevectors_per_group=320,
|
| 238 |
+
num_codevector_groups=2,
|
| 239 |
+
contrastive_logits_temperature=0.1,
|
| 240 |
+
num_negatives=100,
|
| 241 |
+
codevector_dim=256,
|
| 242 |
+
proj_codevector_dim=256,
|
| 243 |
+
diversity_loss_weight=0.1,
|
| 244 |
+
ctc_loss_reduction="sum",
|
| 245 |
+
ctc_zero_infinity=False,
|
| 246 |
+
use_weighted_layer_sum=False,
|
| 247 |
+
classifier_proj_size=256,
|
| 248 |
+
tdnn_dim=(512, 512, 512, 512, 1500),
|
| 249 |
+
tdnn_kernel=(5, 3, 3, 1, 1),
|
| 250 |
+
tdnn_dilation=(1, 2, 3, 1, 1),
|
| 251 |
+
xvector_output_dim=512,
|
| 252 |
+
pad_token_id=0,
|
| 253 |
+
bos_token_id=1,
|
| 254 |
+
eos_token_id=2,
|
| 255 |
+
add_adapter=False,
|
| 256 |
+
adapter_kernel_size=3,
|
| 257 |
+
adapter_stride=2,
|
| 258 |
+
num_adapter_layers=3,
|
| 259 |
+
output_hidden_size=None,
|
| 260 |
+
adapter_attn_dim=None,
|
| 261 |
+
**kwargs,
|
| 262 |
+
):
|
| 263 |
+
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
| 264 |
+
self.hidden_size = hidden_size
|
| 265 |
+
self.feat_extract_norm = feat_extract_norm
|
| 266 |
+
self.feat_extract_activation = feat_extract_activation
|
| 267 |
+
self.conv_dim = list(conv_dim)
|
| 268 |
+
self.conv_stride = list(conv_stride)
|
| 269 |
+
self.conv_kernel = list(conv_kernel)
|
| 270 |
+
self.conv_bias = conv_bias
|
| 271 |
+
self.num_conv_pos_embeddings = num_conv_pos_embeddings
|
| 272 |
+
self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
|
| 273 |
+
self.num_feat_extract_layers = len(self.conv_dim)
|
| 274 |
+
self.num_hidden_layers = num_hidden_layers
|
| 275 |
+
self.intermediate_size = intermediate_size
|
| 276 |
+
self.hidden_act = hidden_act
|
| 277 |
+
self.num_attention_heads = num_attention_heads
|
| 278 |
+
self.hidden_dropout = hidden_dropout
|
| 279 |
+
self.attention_dropout = attention_dropout
|
| 280 |
+
self.activation_dropout = activation_dropout
|
| 281 |
+
self.feat_proj_dropout = feat_proj_dropout
|
| 282 |
+
self.final_dropout = final_dropout
|
| 283 |
+
self.layerdrop = layerdrop
|
| 284 |
+
self.layer_norm_eps = layer_norm_eps
|
| 285 |
+
self.initializer_range = initializer_range
|
| 286 |
+
self.vocab_size = vocab_size
|
| 287 |
+
self.do_stable_layer_norm = do_stable_layer_norm
|
| 288 |
+
self.use_weighted_layer_sum = use_weighted_layer_sum
|
| 289 |
+
|
| 290 |
+
if (
|
| 291 |
+
(len(self.conv_stride) != self.num_feat_extract_layers)
|
| 292 |
+
or (len(self.conv_kernel) != self.num_feat_extract_layers)
|
| 293 |
+
or (len(self.conv_dim) != self.num_feat_extract_layers)
|
| 294 |
+
):
|
| 295 |
+
raise ValueError(
|
| 296 |
+
"Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
|
| 297 |
+
" `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
|
| 298 |
+
f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
|
| 299 |
+
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
| 303 |
+
self.apply_spec_augment = apply_spec_augment
|
| 304 |
+
self.mask_time_prob = mask_time_prob
|
| 305 |
+
self.mask_time_length = mask_time_length
|
| 306 |
+
self.mask_time_min_masks = mask_time_min_masks
|
| 307 |
+
self.mask_feature_prob = mask_feature_prob
|
| 308 |
+
self.mask_feature_length = mask_feature_length
|
| 309 |
+
self.mask_feature_min_masks = mask_feature_min_masks
|
| 310 |
+
|
| 311 |
+
# parameters for pretraining with codevector quantized representations
|
| 312 |
+
self.num_codevectors_per_group = num_codevectors_per_group
|
| 313 |
+
self.num_codevector_groups = num_codevector_groups
|
| 314 |
+
self.contrastive_logits_temperature = contrastive_logits_temperature
|
| 315 |
+
self.feat_quantizer_dropout = feat_quantizer_dropout
|
| 316 |
+
self.num_negatives = num_negatives
|
| 317 |
+
self.codevector_dim = codevector_dim
|
| 318 |
+
self.proj_codevector_dim = proj_codevector_dim
|
| 319 |
+
self.diversity_loss_weight = diversity_loss_weight
|
| 320 |
+
|
| 321 |
+
# ctc loss
|
| 322 |
+
self.ctc_loss_reduction = ctc_loss_reduction
|
| 323 |
+
self.ctc_zero_infinity = ctc_zero_infinity
|
| 324 |
+
|
| 325 |
+
# adapter
|
| 326 |
+
self.add_adapter = add_adapter
|
| 327 |
+
self.adapter_kernel_size = adapter_kernel_size
|
| 328 |
+
self.adapter_stride = adapter_stride
|
| 329 |
+
self.num_adapter_layers = num_adapter_layers
|
| 330 |
+
self.output_hidden_size = output_hidden_size or hidden_size
|
| 331 |
+
self.adapter_attn_dim = adapter_attn_dim
|
| 332 |
+
|
| 333 |
+
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
| 334 |
+
self.classifier_proj_size = classifier_proj_size
|
| 335 |
+
|
| 336 |
+
# XVector-specific parameters. Feel free to ignore for other classes.
|
| 337 |
+
self.tdnn_dim = list(tdnn_dim)
|
| 338 |
+
self.tdnn_kernel = list(tdnn_kernel)
|
| 339 |
+
self.tdnn_dilation = list(tdnn_dilation)
|
| 340 |
+
self.xvector_output_dim = xvector_output_dim
|
| 341 |
+
|
| 342 |
+
@property
|
| 343 |
+
def inputs_to_logits_ratio(self):
|
| 344 |
+
return functools.reduce(operator.mul, self.conv_stride, 1)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
__all__ = ["Wav2Vec2Config"]
|
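Note on the `inputs_to_logits_ratio` property defined above: it is simply the product of the feature-encoder strides, i.e. how many raw audio samples collapse into one encoder frame. A minimal sketch of what this means with the default configuration is shown below; the 16 kHz sampling rate is an assumption taken from the pretrained checkpoints, not something the config enforces.

```python
from transformers import Wav2Vec2Config

config = Wav2Vec2Config()  # default conv_stride = (5, 2, 2, 2, 2, 2, 2)

# Product of the strides: 5 * 2**6 = 320 input samples per output frame.
print(config.inputs_to_logits_ratio)  # 320

# At a 16 kHz sampling rate, one encoder frame therefore covers
# 320 / 16000 = 0.02 s (20 ms) of audio.
# Approximate number of encoder frames for 10 s of audio
# (the convolution kernels shave off a few frames at the edges).
approx_frames = 10 * 16000 // config.inputs_to_logits_ratio
print(approx_frames)  # 500
```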
docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
ADDED
|
@@ -0,0 +1,385 @@
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2021 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Convert Wav2Vec2 checkpoint."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import fairseq
|
| 22 |
+
import torch
|
| 23 |
+
from fairseq.data import Dictionary
|
| 24 |
+
|
| 25 |
+
from transformers import (
|
| 26 |
+
Wav2Vec2Config,
|
| 27 |
+
Wav2Vec2CTCTokenizer,
|
| 28 |
+
Wav2Vec2FeatureExtractor,
|
| 29 |
+
Wav2Vec2ForCTC,
|
| 30 |
+
Wav2Vec2ForPreTraining,
|
| 31 |
+
Wav2Vec2Processor,
|
| 32 |
+
logging,
|
| 33 |
+
)
|
| 34 |
+
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logging.set_verbosity_info()
|
| 38 |
+
logger = logging.get_logger(__name__)
|
| 39 |
+
|
| 40 |
+
MAPPING = {
|
| 41 |
+
"post_extract_proj": "feature_projection.projection",
|
| 42 |
+
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
| 43 |
+
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
| 44 |
+
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
| 45 |
+
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
| 46 |
+
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
| 47 |
+
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
| 48 |
+
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
| 49 |
+
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
| 50 |
+
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
| 51 |
+
"encoder.layer_norm": "encoder.layer_norm",
|
| 52 |
+
"adapter_layer": "encoder.layers.*.adapter_layer",
|
| 53 |
+
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
| 54 |
+
"quantizer.weight_proj": "quantizer.weight_proj",
|
| 55 |
+
"quantizer.vars": "quantizer.codevectors",
|
| 56 |
+
"project_q": "project_q",
|
| 57 |
+
"final_proj": "project_hid",
|
| 58 |
+
"w2v_encoder.proj": "lm_head",
|
| 59 |
+
"mask_emb": "masked_spec_embed",
|
| 60 |
+
"pooling_layer.linear": "projector",
|
| 61 |
+
"pooling_layer.projection": "classifier",
|
| 62 |
+
}
|
| 63 |
+
TOP_LEVEL_KEYS = [
|
| 64 |
+
"lm_head",
|
| 65 |
+
"quantizer.weight_proj",
|
| 66 |
+
"quantizer.codevectors",
|
| 67 |
+
"project_q",
|
| 68 |
+
"project_hid",
|
| 69 |
+
"projector",
|
| 70 |
+
"classifier",
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def read_txt_into_dict(filename):
|
| 75 |
+
result = {}
|
| 76 |
+
with open(filename, "r") as file:
|
| 77 |
+
for line_number, line in enumerate(file):
|
| 78 |
+
line = line.strip()
|
| 79 |
+
if line:
|
| 80 |
+
words = line.split()
|
| 81 |
+
key = line_number
|
| 82 |
+
value = words[0]
|
| 83 |
+
result[key] = value
|
| 84 |
+
return result
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def set_recursively(key, value, full_name, weight_type, hf_pointer):
|
| 88 |
+
for attribute in key.split("."):
|
| 89 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 90 |
+
|
| 91 |
+
hf_param_name = None
|
| 92 |
+
for param_key in PARAM_MAPPING.keys():
|
| 93 |
+
if full_name.endswith(param_key):
|
| 94 |
+
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
|
| 95 |
+
weight_type = "param"
|
| 96 |
+
|
| 97 |
+
# fairseq uses nn.utils.weight_norm() while transformers switches to nn.utils.parametrizations.weight_norm()
|
| 98 |
+
# the mapping between two versions:
|
| 99 |
+
# https://github.com/pytorch/pytorch/blob/56935684c3dfad7841c83c719eeebecb560fe466/torch/nn/utils/parametrizations.py#L389-L395
|
| 100 |
+
|
| 101 |
+
if weight_type is not None and weight_type != "param":
|
| 102 |
+
if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"):
|
| 103 |
+
hf_shape = hf_pointer.parametrizations.weight.original0.shape
|
| 104 |
+
elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"):
|
| 105 |
+
hf_shape = hf_pointer.parametrizations.weight.original1.shape
|
| 106 |
+
else:
|
| 107 |
+
hf_shape = getattr(hf_pointer, weight_type).shape
|
| 108 |
+
elif weight_type is not None and weight_type == "param":
|
| 109 |
+
shape_pointer = hf_pointer
|
| 110 |
+
for attribute in hf_param_name.split("."):
|
| 111 |
+
shape_pointer = getattr(shape_pointer, attribute)
|
| 112 |
+
hf_shape = shape_pointer.shape
|
| 113 |
+
|
| 114 |
+
# let's reduce dimension
|
| 115 |
+
value = value[0]
|
| 116 |
+
else:
|
| 117 |
+
hf_shape = hf_pointer.shape
|
| 118 |
+
|
| 119 |
+
if hf_shape != value.shape:
|
| 120 |
+
raise ValueError(
|
| 121 |
+
f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
|
| 122 |
+
f" {value.shape} for {full_name}"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if weight_type == "weight":
|
| 126 |
+
hf_pointer.weight.data = value
|
| 127 |
+
elif weight_type == "weight_g":
|
| 128 |
+
if hasattr(hf_pointer, "weight_g"):
|
| 129 |
+
hf_pointer.weight_g.data = value
|
| 130 |
+
else:
|
| 131 |
+
hf_pointer.parametrizations.weight.original0.data = value
|
| 132 |
+
elif weight_type == "weight_v":
|
| 133 |
+
if hasattr(hf_pointer, "weight_v"):
|
| 134 |
+
hf_pointer.weight_v.data = value
|
| 135 |
+
else:
|
| 136 |
+
hf_pointer.parametrizations.weight.original1.data = value
|
| 137 |
+
elif weight_type == "bias":
|
| 138 |
+
hf_pointer.bias.data = value
|
| 139 |
+
elif weight_type == "param":
|
| 140 |
+
for attribute in hf_param_name.split("."):
|
| 141 |
+
hf_pointer = getattr(hf_pointer, attribute)
|
| 142 |
+
hf_pointer.data = value
|
| 143 |
+
else:
|
| 144 |
+
hf_pointer.data = value
|
| 145 |
+
|
| 146 |
+
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def rename_dict(key, value, full_name, weight_type, hf_dict):
|
| 150 |
+
hf_param_name = None
|
| 151 |
+
for param_key in PARAM_MAPPING.keys():
|
| 152 |
+
if full_name.endswith(param_key):
|
| 153 |
+
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
|
| 154 |
+
weight_type = "param"
|
| 155 |
+
|
| 156 |
+
if weight_type is not None and weight_type != "param":
|
| 157 |
+
full_key = ".".join([key, weight_type])
|
| 158 |
+
elif weight_type is not None and weight_type == "param":
|
| 159 |
+
full_key = ".".join([key, hf_param_name])
|
| 160 |
+
else:
|
| 161 |
+
full_key = key
|
| 162 |
+
|
| 163 |
+
hf_dict[full_key] = value if "lm_head" in full_key else value[0]
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
PARAM_MAPPING = {
|
| 167 |
+
"W_a": "linear_1.weight",
|
| 168 |
+
"W_b": "linear_2.weight",
|
| 169 |
+
"b_a": "linear_1.bias",
|
| 170 |
+
"b_b": "linear_2.bias",
|
| 171 |
+
"ln_W": "norm.weight",
|
| 172 |
+
"ln_b": "norm.bias",
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
|
| 177 |
+
is_used = False
|
| 178 |
+
for key, mapped_key in MAPPING.items():
|
| 179 |
+
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
|
| 180 |
+
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
| 181 |
+
is_used = True
|
| 182 |
+
if "*" in mapped_key:
|
| 183 |
+
layer_index = name.split(key)[0].split(".")[-2]
|
| 184 |
+
mapped_key = mapped_key.replace("*", layer_index)
|
| 185 |
+
if "weight_g" in name:
|
| 186 |
+
weight_type = "weight_g"
|
| 187 |
+
elif "weight_v" in name:
|
| 188 |
+
weight_type = "weight_v"
|
| 189 |
+
elif "bias" in name:
|
| 190 |
+
weight_type = "bias"
|
| 191 |
+
elif "weight" in name:
|
| 192 |
+
# TODO: don't match quantizer.weight_proj
|
| 193 |
+
weight_type = "weight"
|
| 194 |
+
else:
|
| 195 |
+
weight_type = None
|
| 196 |
+
if hf_dict is not None:
|
| 197 |
+
rename_dict(mapped_key, value, name, weight_type, hf_dict)
|
| 198 |
+
else:
|
| 199 |
+
set_recursively(mapped_key, value, name, weight_type, hf_model)
|
| 200 |
+
return is_used
|
| 201 |
+
return is_used
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def recursively_load_weights(fairseq_model, hf_model, is_headless):
|
| 205 |
+
unused_weights = []
|
| 206 |
+
fairseq_dict = fairseq_model.state_dict()
|
| 207 |
+
|
| 208 |
+
feature_extractor = hf_model.wav2vec2.feature_extractor
|
| 209 |
+
|
| 210 |
+
for name, value in fairseq_dict.items():
|
| 211 |
+
is_used = False
|
| 212 |
+
if "conv_layers" in name:
|
| 213 |
+
load_conv_layer(
|
| 214 |
+
name,
|
| 215 |
+
value,
|
| 216 |
+
feature_extractor,
|
| 217 |
+
unused_weights,
|
| 218 |
+
hf_model.config.feat_extract_norm == "group",
|
| 219 |
+
)
|
| 220 |
+
is_used = True
|
| 221 |
+
else:
|
| 222 |
+
is_used = load_wav2vec2_layer(name, value, hf_model)
|
| 223 |
+
if not is_used:
|
| 224 |
+
unused_weights.append(name)
|
| 225 |
+
|
| 226 |
+
logger.warning(f"Unused weights: {unused_weights}")
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
| 230 |
+
name = full_name.split("conv_layers.")[-1]
|
| 231 |
+
items = name.split(".")
|
| 232 |
+
layer_id = int(items[0])
|
| 233 |
+
type_id = int(items[1])
|
| 234 |
+
|
| 235 |
+
if type_id == 0:
|
| 236 |
+
if "bias" in name:
|
| 237 |
+
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
|
| 238 |
+
raise ValueError(
|
| 239 |
+
f"{full_name} has size {value.shape}, but"
|
| 240 |
+
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
| 241 |
+
)
|
| 242 |
+
feature_extractor.conv_layers[layer_id].conv.bias.data = value
|
| 243 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 244 |
+
elif "weight" in name:
|
| 245 |
+
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
|
| 246 |
+
raise ValueError(
|
| 247 |
+
f"{full_name} has size {value.shape}, but"
|
| 248 |
+
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
| 249 |
+
)
|
| 250 |
+
feature_extractor.conv_layers[layer_id].conv.weight.data = value
|
| 251 |
+
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
| 252 |
+
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
|
| 253 |
+
if "bias" in name:
|
| 254 |
+
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
|
| 255 |
+
raise ValueError(
|
| 256 |
+
f"{full_name} has size {value.shape}, but"
|
| 257 |
+
f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
|
| 258 |
+
)
|
| 259 |
+
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
|
| 260 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 261 |
+
elif "weight" in name:
|
| 262 |
+
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
f"{full_name} has size {value.shape}, but"
|
| 265 |
+
f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
|
| 266 |
+
)
|
| 267 |
+
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
|
| 268 |
+
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
| 269 |
+
else:
|
| 270 |
+
unused_weights.append(full_name)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@torch.no_grad()
|
| 274 |
+
def convert_wav2vec2_checkpoint(
|
| 275 |
+
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
|
| 276 |
+
):
|
| 277 |
+
"""
|
| 278 |
+
Copy/paste/tweak model's weights to transformers design.
|
| 279 |
+
"""
|
| 280 |
+
if config_path is not None:
|
| 281 |
+
config = Wav2Vec2Config.from_pretrained(config_path)
|
| 282 |
+
else:
|
| 283 |
+
config = Wav2Vec2Config()
|
| 284 |
+
|
| 285 |
+
if is_seq_class:
|
| 286 |
+
id2label = read_txt_into_dict(dict_path)
|
| 287 |
+
config.id2label = id2label
|
| 288 |
+
hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
|
| 289 |
+
feature_extractor = Wav2Vec2FeatureExtractor(
|
| 290 |
+
feature_size=1,
|
| 291 |
+
sampling_rate=16000,
|
| 292 |
+
padding_value=0,
|
| 293 |
+
do_normalize=True,
|
| 294 |
+
return_attention_mask=True,
|
| 295 |
+
)
|
| 296 |
+
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
| 297 |
+
|
| 298 |
+
elif is_finetuned:
|
| 299 |
+
if dict_path:
|
| 300 |
+
target_dict = Dictionary.load(dict_path)
|
| 301 |
+
|
| 302 |
+
# important change bos & pad token id since CTC symbol is <pad> and
|
| 303 |
+
# not <s> as in fairseq
|
| 304 |
+
config.bos_token_id = target_dict.pad_index
|
| 305 |
+
config.pad_token_id = target_dict.bos_index
|
| 306 |
+
config.eos_token_id = target_dict.eos_index
|
| 307 |
+
config.vocab_size = len(target_dict.symbols)
|
| 308 |
+
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
|
| 309 |
+
if not os.path.isdir(pytorch_dump_folder_path):
|
| 310 |
+
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
|
| 311 |
+
return
|
| 312 |
+
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
|
| 313 |
+
vocab_dict = target_dict.indices
|
| 314 |
+
|
| 315 |
+
# fairseq has the <pad> and <s> switched
|
| 316 |
+
vocab_dict["<pad>"] = 0
|
| 317 |
+
vocab_dict["<s>"] = 1
|
| 318 |
+
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
|
| 319 |
+
json.dump(vocab_dict, vocab_handle)
|
| 320 |
+
tokenizer = Wav2Vec2CTCTokenizer(
|
| 321 |
+
vocab_path,
|
| 322 |
+
unk_token=target_dict.unk_word,
|
| 323 |
+
pad_token=target_dict.pad_word,
|
| 324 |
+
bos_token=target_dict.bos_word,
|
| 325 |
+
eos_token=target_dict.eos_word,
|
| 326 |
+
word_delimiter_token="|",
|
| 327 |
+
do_lower_case=False,
|
| 328 |
+
)
|
| 329 |
+
return_attention_mask = True if config.feat_extract_norm == "layer" else False
|
| 330 |
+
feature_extractor = Wav2Vec2FeatureExtractor(
|
| 331 |
+
feature_size=1,
|
| 332 |
+
sampling_rate=16000,
|
| 333 |
+
padding_value=0,
|
| 334 |
+
do_normalize=True,
|
| 335 |
+
return_attention_mask=return_attention_mask,
|
| 336 |
+
)
|
| 337 |
+
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
|
| 338 |
+
processor.save_pretrained(pytorch_dump_folder_path)
|
| 339 |
+
|
| 340 |
+
hf_wav2vec = Wav2Vec2ForCTC(config)
|
| 341 |
+
else:
|
| 342 |
+
hf_wav2vec = Wav2Vec2ForPreTraining(config)
|
| 343 |
+
|
| 344 |
+
if is_finetuned or is_seq_class:
|
| 345 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
| 346 |
+
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
|
| 347 |
+
)
|
| 348 |
+
else:
|
| 349 |
+
task_arg = argparse.Namespace(task="audio_pretraining")
|
| 350 |
+
task = fairseq.tasks.setup_task(task_arg)
|
| 351 |
+
|
| 352 |
+
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
|
| 353 |
+
|
| 354 |
+
model = model[0].eval()
|
| 355 |
+
|
| 356 |
+
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
|
| 357 |
+
|
| 358 |
+
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
if __name__ == "__main__":
|
| 362 |
+
parser = argparse.ArgumentParser()
|
| 363 |
+
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
| 364 |
+
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
| 365 |
+
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
| 366 |
+
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
| 369 |
+
)
|
| 370 |
+
parser.add_argument(
|
| 371 |
+
"--is_seq_class",
|
| 372 |
+
action="store_true",
|
| 373 |
+
help="Whether the model to convert is a fine-tuned sequence classification model or not",
|
| 374 |
+
)
|
| 375 |
+
args = parser.parse_args()
|
| 376 |
+
|
| 377 |
+
is_finetuned = not args.not_finetuned and not args.is_seq_class
|
| 378 |
+
convert_wav2vec2_checkpoint(
|
| 379 |
+
args.checkpoint_path,
|
| 380 |
+
args.pytorch_dump_folder_path,
|
| 381 |
+
args.config_path,
|
| 382 |
+
args.dict_path,
|
| 383 |
+
is_finetuned,
|
| 384 |
+
args.is_seq_class,
|
| 385 |
+
)
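For reference, the script above is normally invoked from the command line with the arguments defined in its `argparse` block; the equivalent call from Python is sketched below. The checkpoint, dictionary, and output paths are placeholders, and the example assumes `fairseq` and a fine-tuned CTC checkpoint are available locally.

```python
# Hypothetical paths -- substitute your own fairseq checkpoint, dict file and output folder.
fairseq_checkpoint = "/path/to/wav2vec_small_960h.pt"
fairseq_dict = "/path/to/dict.ltr.txt"
output_dir = "/path/to/hf_wav2vec2"

# Equivalent to:
#   python convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py \
#       --checkpoint_path ... --dict_path ... --pytorch_dump_folder_path ...
convert_wav2vec2_checkpoint(
    fairseq_checkpoint,
    output_dir,
    config_path=None,      # fall back to the default Wav2Vec2Config
    dict_path=fairseq_dict,
    is_finetuned=True,     # the checkpoint carries a CTC head
    is_seq_class=False,
)
```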
|