Student0809 committed
Commit 1100969 · verified · 1 Parent(s): 463b4bf

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py +135 -0
  2. docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py +1597 -0
  3. docs/transformers/build/lib/transformers/models/vit/__init__.py +32 -0
  4. docs/transformers/build/lib/transformers/models/vit/configuration_vit.py +151 -0
  5. docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py +218 -0
  6. docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py +254 -0
  7. docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py +38 -0
  8. docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py +288 -0
  9. docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py +45 -0
  10. docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py +677 -0
  11. docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py +907 -0
  12. docs/transformers/build/lib/transformers/models/vit/modeling_vit.py +883 -0
  13. docs/transformers/build/lib/transformers/models/vit_mae/__init__.py +28 -0
  14. docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py +140 -0
  15. docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py +178 -0
  16. docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py +1375 -0
  17. docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py +1163 -0
  18. docs/transformers/build/lib/transformers/models/vit_msn/__init__.py +27 -0
  19. docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py +115 -0
  20. docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py +245 -0
  21. docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py +741 -0
  22. docs/transformers/build/lib/transformers/models/vitdet/__init__.py +27 -0
  23. docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py +156 -0
  24. docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py +883 -0
  25. docs/transformers/build/lib/transformers/models/vitmatte/__init__.py +28 -0
  26. docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py +136 -0
  27. docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py +170 -0
  28. docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py +272 -0
  29. docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py +341 -0
  30. docs/transformers/build/lib/transformers/models/vitpose/__init__.py +28 -0
  31. docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py +126 -0
  32. docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py +428 -0
  33. docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py +684 -0
  34. docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py +340 -0
  35. docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py +17 -0
  36. docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +139 -0
  37. docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +579 -0
  38. docs/transformers/build/lib/transformers/models/vits/__init__.py +28 -0
  39. docs/transformers/build/lib/transformers/models/vits/configuration_vits.py +253 -0
  40. docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py +390 -0
  41. docs/transformers/build/lib/transformers/models/vits/modeling_vits.py +1493 -0
  42. docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py +246 -0
  43. docs/transformers/build/lib/transformers/models/vivit/__init__.py +28 -0
  44. docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py +119 -0
  45. docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py +231 -0
  46. docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py +407 -0
  47. docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py +844 -0
  48. docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py +32 -0
  49. docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py +347 -0
  50. docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +385 -0
docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py ADDED
@@ -0,0 +1,135 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """VisualBERT model configuration"""
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VisualBertConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`VisualBertModel`]. It is used to instantiate a
+     VisualBERT model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the VisualBERT
+     [uclanlp/visualbert-vqa-coco-pre](https://huggingface.co/uclanlp/visualbert-vqa-coco-pre) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 30522):
+             Vocabulary size of the VisualBERT model. Defines the number of different tokens that can be represented
+             by the `inputs_ids` passed when calling [`VisualBertModel`].
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         visual_embedding_dim (`int`, *optional*, defaults to 512):
+             Dimensionality of the visual embeddings to be passed to the model.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to 3072):
+             Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+         hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+             The dropout ratio for the attention probabilities.
+         max_position_embeddings (`int`, *optional*, defaults to 512):
+             The maximum sequence length that this model might ever be used with. Typically set this to something large
+             just in case (e.g., 512 or 1024 or 2048).
+         type_vocab_size (`int`, *optional*, defaults to 2):
+             The vocabulary size of the `token_type_ids` passed when calling [`VisualBertModel`].
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         bypass_transformer (`bool`, *optional*, defaults to `False`):
+             Whether or not the model should bypass the transformer for the visual embeddings. If set to `True`, the
+             model directly concatenates the visual embeddings from [`VisualBertEmbeddings`] with the text output from
+             the transformer, and then passes the result to a self-attention layer.
+         special_visual_initialize (`bool`, *optional*, defaults to `True`):
+             Whether or not the visual token type and position type embedding weights should be initialized the same
+             as the textual token type and position type embeddings. When set to `True`, the weights of the textual
+             token type and position type embeddings are copied to the respective visual embedding layers.
+
+
+     Example:
+
+     ```python
+     >>> from transformers import VisualBertConfig, VisualBertModel
+
+     >>> # Initializing a VisualBERT visualbert-vqa-coco-pre style configuration
+     >>> configuration = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
+
+     >>> # Initializing a model (with random weights) from the visualbert-vqa-coco-pre style configuration
+     >>> model = VisualBertModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "visual_bert"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         hidden_size=768,
+         visual_embedding_dim=512,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         bypass_transformer=False,
+         special_visual_initialize=True,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.visual_embedding_dim = visual_embedding_dim
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.initializer_range = initializer_range
+         self.type_vocab_size = type_vocab_size
+         self.layer_norm_eps = layer_norm_eps
+         self.bypass_transformer = bypass_transformer
+         self.special_visual_initialize = special_visual_initialize
+
+
+ __all__ = ["VisualBertConfig"]
docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py ADDED
@@ -0,0 +1,1597 @@
+ # coding=utf-8
+ # Copyright 2021 The UCLA NLP Authors and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch VisualBERT model."""
+
+ import math
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax
+
+ from ...activations import ACT2FN
+ from ...modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPooling,
+     MultipleChoiceModelOutput,
+     SequenceClassifierOutput,
+ )
+ from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+ from ...utils import (
+     ModelOutput,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     logging,
+     replace_return_docstrings,
+ )
+ from .configuration_visual_bert import VisualBertConfig
+
+
+ logger = logging.get_logger(__name__)
+
+ _CONFIG_FOR_DOC = "VisualBertConfig"
+ _CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"
+
+
+ class VisualBertEmbeddings(nn.Module):
+     """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+         # any TensorFlow checkpoint file
+
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.register_buffer(
+             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+         )
+
+         # For Visual Features
+         # Token type and position embedding for image features
+         self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+         self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+
+         if config.special_visual_initialize:
+             self.visual_token_type_embeddings.weight.data = nn.Parameter(
+                 self.token_type_embeddings.weight.data.clone(), requires_grad=True
+             )
+             self.visual_position_embeddings.weight.data = nn.Parameter(
+                 self.position_embeddings.weight.data.clone(), requires_grad=True
+             )
+
+         self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)
+
+     def forward(
+         self,
+         input_ids=None,
+         token_type_ids=None,
+         position_ids=None,
+         inputs_embeds=None,
+         visual_embeds=None,
+         visual_token_type_ids=None,
+         image_text_alignment=None,
+     ):
+         if input_ids is not None:
+             input_shape = input_ids.size()
+         else:
+             input_shape = inputs_embeds.size()[:-1]
+
+         seq_length = input_shape[1]
+
+         if position_ids is None:
+             position_ids = self.position_ids[:, :seq_length]
+
+         if inputs_embeds is None:
+             inputs_embeds = self.word_embeddings(input_ids)
+
+         if token_type_ids is None:
+             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+         embeddings = inputs_embeds + token_type_embeddings
+
+         # Absolute Position Embeddings
+         position_embeddings = self.position_embeddings(position_ids)
+         embeddings += position_embeddings
+
+         if visual_embeds is not None:
+             if visual_token_type_ids is None:
+                 visual_token_type_ids = torch.ones(
+                     visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device
+                 )
+
+             visual_embeds = self.visual_projection(visual_embeds)
+             visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)
+
+             if image_text_alignment is not None:
+                 # image_text_alignment = Batch x image_length x alignment_number.
+                 # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.
+
+                 dtype = token_type_embeddings.dtype
+                 image_text_alignment_mask = (image_text_alignment != -1).long()
+                 # Get rid of the -1.
+                 image_text_alignment = image_text_alignment_mask * image_text_alignment
+
+                 # Batch x image_length x alignment length x dim
+                 visual_position_embeddings = self.position_embeddings(image_text_alignment)
+                 visual_position_embeddings *= image_text_alignment_mask.to(dtype=dtype).unsqueeze(-1)
+                 visual_position_embeddings = visual_position_embeddings.sum(2)
+
+                 # We want to average along the alignment_number dimension.
+                 image_text_alignment_mask = image_text_alignment_mask.to(dtype=dtype).sum(2)
+
+                 if (image_text_alignment_mask == 0).sum() != 0:
+                     image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid divide by zero error
+                     logger.warning(
+                         "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero"
+                         " error."
+                     )
+                 visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)
+
+                 visual_position_ids = torch.zeros(
+                     *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
+                 )
+
+                 # When fine-tuning the detector, the image_text_alignment is sometimes padded too long.
+                 if visual_position_embeddings.size(1) != visual_embeds.size(1):
+                     if visual_position_embeddings.size(1) < visual_embeds.size(1):
+                         raise ValueError(
+                             f"Visual position embeddings length: {visual_position_embeddings.size(1)} "
+                             f"should be the same as `visual_embeds` length: {visual_embeds.size(1)}"
+                         )
+                     visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]
+
+                 visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(
+                     visual_position_ids
+                 )
+             else:
+                 visual_position_ids = torch.zeros(
+                     *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device
+                 )
+                 visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)
+
+             visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings
+
+             embeddings = torch.cat((embeddings, visual_embeddings), dim=1)
+
+         embeddings = self.LayerNorm(embeddings)
+         embeddings = self.dropout(embeddings)
+         return embeddings
+
+
+ class VisualBertSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+             raise ValueError(
+                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                 f"heads ({config.num_attention_heads})"
+             )
+
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transpose_for_scores(self, x):
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         output_attentions=False,
+     ):
+         mixed_query_layer = self.query(hidden_states)
+
+         key_layer = self.transpose_for_scores(self.key(hidden_states))
+         value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the VisualBertModel forward() function)
+             attention_scores = attention_scores + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+
+         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+         return outputs
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->VisualBert
+ class VisualBertSelfOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         return hidden_states
+
+
+ class VisualBertAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.self = VisualBertSelfAttention(config)
+         self.output = VisualBertSelfOutput(config)
+         self.pruned_heads = set()
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+         )
+
+         # Prune linear layers
+         self.self.query = prune_linear_layer(self.self.query, index)
+         self.self.key = prune_linear_layer(self.self.key, index)
+         self.self.value = prune_linear_layer(self.self.value, index)
+         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+         # Update hyper params and store pruned heads
+         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         output_attentions=False,
+     ):
+         self_outputs = self.self(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             output_attentions,
+         )
+         attention_output = self.output(self_outputs[0], hidden_states)
+         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+         return outputs
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->VisualBert
+ class VisualBertIntermediate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+         if isinstance(config.hidden_act, str):
+             self.intermediate_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.intermediate_act_fn = config.hidden_act
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.intermediate_act_fn(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->VisualBert
+ class VisualBertOutput(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states + input_tensor)
+         return hidden_states
+
+
+ class VisualBertLayer(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = VisualBertAttention(config)
+         self.intermediate = VisualBertIntermediate(config)
+         self.output = VisualBertOutput(config)
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         output_attentions=False,
+     ):
+         self_attention_outputs = self.attention(
+             hidden_states,
+             attention_mask,
+             head_mask,
+             output_attentions=output_attentions,
+         )
+         attention_output = self_attention_outputs[0]
+
+         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+         layer_output = apply_chunking_to_forward(
+             self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+         )
+         outputs = (layer_output,) + outputs
+
+         return outputs
+
+     def feed_forward_chunk(self, attention_output):
+         intermediate_output = self.intermediate(attention_output)
+         layer_output = self.output(intermediate_output, attention_output)
+         return layer_output
+
+
+ class VisualBertEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         head_mask=None,
+         output_attentions=False,
+         output_hidden_states=False,
+         return_dict=True,
+     ):
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     attention_mask,
+                     layer_head_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
+
+             hidden_states = layer_outputs[0]
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     hidden_states,
+                     all_hidden_states,
+                     all_self_attentions,
+                 ]
+                 if v is not None
+             )
+         return BaseModelOutput(
+             last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+         )
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->VisualBert
+ class VisualBertPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0]
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->VisualBert
+ class VisualBertPredictionHeadTransform(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         if isinstance(config.hidden_act, str):
+             self.transform_act_fn = ACT2FN[config.hidden_act]
+         else:
+             self.transform_act_fn = config.hidden_act
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->VisualBert
+ class VisualBertLMPredictionHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.transform = VisualBertPredictionHeadTransform(config)
+
+         # The output weights are the same as the input embeddings, but there is
+         # an output-only bias for each token.
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+         self.decoder.bias = self.bias
+
+     def _tie_weights(self):
+         self.decoder.bias = self.bias
+
+     def forward(self, hidden_states):
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->VisualBert
+ class VisualBertPreTrainingHeads(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.predictions = VisualBertLMPredictionHead(config)
+         self.seq_relationship = nn.Linear(config.hidden_size, 2)
+
+     def forward(self, sequence_output, pooled_output):
+         prediction_scores = self.predictions(sequence_output)
+         seq_relationship_score = self.seq_relationship(pooled_output)
+         return prediction_scores, seq_relationship_score
+
+
+ class VisualBertPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = VisualBertConfig
+     base_model_prefix = "visual_bert"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if hasattr(module, "bias") and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, VisualBertLMPredictionHead):
+             module.bias.data.zero_()
+
+
+ @dataclass
+ class VisualBertForPreTrainingOutput(ModelOutput):
+     """
+     Output type of [`VisualBertForPreTraining`].
+
+     Args:
+         loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
+             Total loss as the sum of the masked language modeling loss and the sentence-image prediction
+             (classification) loss.
+         prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+             Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation
+             before SoftMax).
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+             shape `(batch_size, sequence_length, hidden_size)`.
+
+             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`.
+
+             Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     prediction_logits: Optional[torch.FloatTensor] = None
+     seq_relationship_logits: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ VISUAL_BERT_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+     etc.)
+
+     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+     and behavior.
+
+     Parameters:
+         config ([`VisualBertConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ VISUAL_BERT_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `({0})`):
+             Indices of input sequence tokens in the vocabulary.
+
+             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+             [`PreTrainedTokenizer.__call__`] for details.
+
+             [What are input IDs?](../glossary#input-ids)
+         attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+         token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
+             1]`:
+
+             - 0 corresponds to a *sentence A* token,
+             - 1 corresponds to a *sentence B* token.
+
+             [What are token type IDs?](../glossary#token-type-ids)
+         position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+             Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+             config.max_position_embeddings - 1]`.
+
+             [What are position IDs?](../glossary#position-ids)
+         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+
+         inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+
+         visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*):
+             The embedded representation of the visual inputs, generally derived using an object detector.
+
+         visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, visual_seq_length)`, *optional*):
+             Mask to avoid performing attention on visual embeddings. Mask values selected in `[0, 1]`:
+
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+
+             [What are attention masks?](../glossary#attention-mask)
+         visual_token_type_ids (`torch.LongTensor` of shape `(batch_size, visual_seq_length)`, *optional*):
+             Segment token indices to indicate different portions of the visual embeds.
+
+             [What are token type IDs?](../glossary#token-type-ids) The authors of VisualBERT set the
+             *visual_token_type_ids* to *1* for all tokens.
+
+         image_text_alignment (`torch.LongTensor` of shape `(batch_size, visual_seq_length, alignment_number)`, *optional*):
+             Image-text alignment used to decide the position IDs of the visual embeddings.
+
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",
+     VISUAL_BERT_START_DOCSTRING,
+ )
+ class VisualBertModel(VisualBertPreTrainedModel):
+     """
+
+     The model can behave as an encoder (with only self-attention) following the architecture described in [Attention is
+     all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+     Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+     """
+
+     def __init__(self, config, add_pooling_layer=True):
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = VisualBertEmbeddings(config)
+         self.encoder = VisualBertEncoder(config)
+
+         self.pooler = VisualBertPooler(config) if add_pooling_layer else None
+
+         self.bypass_transformer = config.bypass_transformer
+
+         if self.bypass_transformer:
+             self.additional_layer = VisualBertLayer(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.embeddings.word_embeddings = value
+
+     def _prune_heads(self, heads_to_prune):
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+         class PreTrainedModel
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         visual_embeds: Optional[torch.FloatTensor] = None,
+         visual_attention_mask: Optional[torch.LongTensor] = None,
+         visual_token_type_ids: Optional[torch.LongTensor] = None,
+         image_text_alignment: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
+         r"""
+
+         Returns:
+
+         Example:
+
+         ```python
+         # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image.
+         from transformers import AutoTokenizer, VisualBertModel
+         import torch
+
+         tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+         model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
+
+         inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")
+         visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+         visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
+         visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+         inputs.update(
+             {
+                 "visual_embeds": visual_embeds,
+                 "visual_token_type_ids": visual_token_type_ids,
+                 "visual_attention_mask": visual_attention_mask,
+             }
+         )
+
+         outputs = model(**inputs)
+
+         last_hidden_states = outputs.last_hidden_state
+         ```"""
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+             input_shape = input_ids.size()
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         batch_size, seq_length = input_shape
+         device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+         if visual_embeds is not None:
+             visual_input_shape = visual_embeds.size()[:-1]
+
+         if attention_mask is None:
+             attention_mask = torch.ones(input_shape, device=device)
+
+         if visual_embeds is not None and visual_attention_mask is None:
+             visual_attention_mask = torch.ones(visual_input_shape, device=device)
+
+         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+         # ourselves in which case we just need to make it broadcastable to all heads.
+         if visual_embeds is not None:
+             combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)
+             extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                 combined_attention_mask, (batch_size, input_shape + visual_input_shape)
+             )
+
+         else:
+             extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+                 attention_mask, (batch_size, input_shape)
+             )
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         embedding_output = self.embeddings(
+             input_ids=input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             inputs_embeds=inputs_embeds,
+             visual_embeds=visual_embeds,
+             visual_token_type_ids=visual_token_type_ids,
+             image_text_alignment=image_text_alignment,
+         )
+
+         if self.bypass_transformer and visual_embeds is not None:
+             text_length = input_ids.size(1)
+             text_embedding_output = embedding_output[:, :text_length, :]
+             visual_embedding_output = embedding_output[:, text_length:, :]
+
+             text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]
+
+             encoded_outputs = self.encoder(
+                 text_embedding_output,
+                 attention_mask=text_extended_attention_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             sequence_output = encoded_outputs[0]
+             concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)
+             sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)
+             pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+         else:
+             encoder_outputs = self.encoder(
+                 embedding_output,
+                 attention_mask=extended_attention_mask,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             sequence_output = encoder_outputs[0]
+
+             pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+         if not return_dict:
+             return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ @add_start_docstrings(
+     """
+     VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
+     `sentence-image prediction (classification)` head.
+     """,
+     VISUAL_BERT_START_DOCSTRING,
+ )
+ class VisualBertForPreTraining(VisualBertPreTrainedModel):
+     _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.visual_bert = VisualBertModel(config)
+         self.cls = VisualBertPreTrainingHeads(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.cls.predictions.decoder = new_embeddings
+         self.cls.predictions.bias = new_embeddings.bias
+
+     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         visual_embeds: Optional[torch.FloatTensor] = None,
+         visual_attention_mask: Optional[torch.LongTensor] = None,
+         visual_token_type_ids: Optional[torch.LongTensor] = None,
+         image_text_alignment: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+         sentence_image_labels: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], VisualBertForPreTrainingOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         sentence_image_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence
+             pair (see the `input_ids` docstring). Indices should be in `[0, 1]`:
+
+             - 0 indicates sequence B is a matching pair of sequence A for the given image,
+             - 1 indicates sequence B is a random sequence w.r.t. A for the given image.
+
+         Returns:
+
+         Example:
+
+         ```python
+         # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
+         from transformers import AutoTokenizer, VisualBertForPreTraining
+
+         tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+         model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
+
+         inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+         visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+         visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
+         visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+         inputs.update(
+             {
+                 "visual_embeds": visual_embeds,
+                 "visual_token_type_ids": visual_token_type_ids,
+                 "visual_attention_mask": visual_attention_mask,
+             }
+         )
+         max_length = inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]
+         labels = tokenizer(
+             "The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=max_length
+         )["input_ids"]
+         sentence_image_labels = torch.tensor(1).unsqueeze(0)  # Batch_size
+
+
+         outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels)
+         loss = outputs.loss
+         prediction_logits = outputs.prediction_logits
+         seq_relationship_logits = outputs.seq_relationship_logits
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if labels is not None:
+             total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
+             if labels.size(-1) != total_size:
+                 raise ValueError(
+                     "The labels provided should have the same sequence length as the total attention mask. "
+                     f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
+                 )
+
+         outputs = self.visual_bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             visual_embeds=visual_embeds,
+             visual_attention_mask=visual_attention_mask,
+             visual_token_type_ids=visual_token_type_ids,
+             image_text_alignment=image_text_alignment,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output, pooled_output = outputs[:2]
+         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+         total_loss = None
+         if labels is not None and sentence_image_labels is not None:
+             loss_fct = CrossEntropyLoss()
+             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+             sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))
+             total_loss = masked_lm_loss + sentence_image_loss
+
+         elif labels is not None:
+             loss_fct = CrossEntropyLoss()
+             total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores, seq_relationship_score) + outputs[2:]
+             return ((total_loss,) + output) if total_loss is not None else output
+
+         return VisualBertForPreTrainingOutput(
+             loss=total_loss,
+             prediction_logits=prediction_scores,
+             seq_relationship_logits=seq_relationship_score,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ @add_start_docstrings(
+     """
+     VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
+     a softmax) e.g. for VCR tasks.
+     """,
+     VISUAL_BERT_START_DOCSTRING,
+ )
+ class VisualBertForMultipleChoice(VisualBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.visual_bert = VisualBertModel(config)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.cls = nn.Linear(config.hidden_size, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(
+         VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
+     )
+     @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         visual_embeds: Optional[torch.FloatTensor] = None,
+         visual_attention_mask: Optional[torch.LongTensor] = None,
+         visual_token_type_ids: Optional[torch.LongTensor] = None,
+         image_text_alignment: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.LongTensor] = None,
+     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above.)
+
+         Returns:
+
+         Example:
+
+         ```python
+         # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
+         from transformers import AutoTokenizer, VisualBertForMultipleChoice
+         import torch
+
+         tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+         model = VisualBertForMultipleChoice.from_pretrained("uclanlp/visualbert-vcr")
+
+         prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+         choice0 = "It is eaten with a fork and a knife."
+         choice1 = "It is eaten while held in the hand."
+
1055
+ visual_embeds = get_visual_embeddings(image)
1056
+ # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)
1057
+ visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape)
1058
+ visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
1059
+ visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
1060
+
1061
+ labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
1062
+
1063
+ encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors="pt", padding=True)
1064
+ # batch size is 1
1065
+ inputs_dict = {k: v.unsqueeze(0) for k, v in encoding.items()}
1066
+ inputs_dict.update(
1067
+ {
1068
+ "visual_embeds": visual_embeds,
1069
+ "visual_attention_mask": visual_attention_mask,
1070
+ "visual_token_type_ids": visual_token_type_ids,
1071
+ "labels": labels,
1072
+ }
1073
+ )
1074
+ outputs = model(**inputs_dict)
1075
+
1076
+ loss = outputs.loss
1077
+ logits = outputs.logits
1078
+ ```"""
1079
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1080
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1081
+
1082
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1083
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1084
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1085
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1086
+ inputs_embeds = (
1087
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1088
+ if inputs_embeds is not None
1089
+ else None
1090
+ )
1091
+
1092
+ visual_embeds = (
1093
+ visual_embeds.view(-1, visual_embeds.size(-2), visual_embeds.size(-1))
1094
+ if visual_embeds is not None
1095
+ else None
1096
+ )
1097
+ visual_attention_mask = (
1098
+ visual_attention_mask.view(-1, visual_attention_mask.size(-1))
1099
+ if visual_attention_mask is not None
1100
+ else None
1101
+ )
1102
+ visual_token_type_ids = (
1103
+ visual_token_type_ids.view(-1, visual_token_type_ids.size(-1))
1104
+ if visual_token_type_ids is not None
1105
+ else None
1106
+ )
1107
+
1108
+ outputs = self.visual_bert(
1109
+ input_ids,
1110
+ attention_mask=attention_mask,
1111
+ token_type_ids=token_type_ids,
1112
+ position_ids=position_ids,
1113
+ head_mask=head_mask,
1114
+ inputs_embeds=inputs_embeds,
1115
+ visual_embeds=visual_embeds,
1116
+ visual_attention_mask=visual_attention_mask,
1117
+ visual_token_type_ids=visual_token_type_ids,
1118
+ image_text_alignment=image_text_alignment,
1119
+ output_attentions=output_attentions,
1120
+ output_hidden_states=output_hidden_states,
1121
+ return_dict=return_dict,
1122
+ )
1123
+
1124
+ _, pooled_output = outputs[0], outputs[1]
1125
+
1126
+ pooled_output = self.dropout(pooled_output)
1127
+ logits = self.cls(pooled_output)
1128
+ reshaped_logits = logits.view(-1, num_choices)
1129
+
1130
+ loss = None
1131
+ if labels is not None:
1132
+ loss_fct = CrossEntropyLoss()
1133
+ loss = loss_fct(reshaped_logits, labels)
1134
+
1135
+ if not return_dict:
1136
+ output = (reshaped_logits,) + outputs[2:]
1137
+ return ((loss,) + output) if loss is not None else output
1138
+
1139
+ return MultipleChoiceModelOutput(
1140
+ loss=loss,
1141
+ logits=reshaped_logits,
1142
+ hidden_states=outputs.hidden_states,
1143
+ attentions=outputs.attentions,
1144
+ )
1145
+
1146
+
1147
+ @add_start_docstrings(
1148
+ """
1149
+ VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled
1150
+ output) for VQA.
1151
+ """,
1152
+ VISUAL_BERT_START_DOCSTRING,
1153
+ )
1154
+ class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):
1155
+ def __init__(self, config):
1156
+ super().__init__(config)
1157
+ self.num_labels = config.num_labels
1158
+
1159
+ self.visual_bert = VisualBertModel(config)
1160
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1161
+ self.cls = nn.Linear(config.hidden_size, config.num_labels)
1162
+
1163
+ # Initialize weights and apply final processing
1164
+ self.post_init()
1165
+
1166
+ @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1167
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1168
+ def forward(
1169
+ self,
1170
+ input_ids: Optional[torch.LongTensor] = None,
1171
+ attention_mask: Optional[torch.LongTensor] = None,
1172
+ token_type_ids: Optional[torch.LongTensor] = None,
1173
+ position_ids: Optional[torch.LongTensor] = None,
1174
+ head_mask: Optional[torch.LongTensor] = None,
1175
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1176
+ visual_embeds: Optional[torch.FloatTensor] = None,
1177
+ visual_attention_mask: Optional[torch.LongTensor] = None,
1178
+ visual_token_type_ids: Optional[torch.LongTensor] = None,
1179
+ image_text_alignment: Optional[torch.LongTensor] = None,
1180
+ output_attentions: Optional[bool] = None,
1181
+ output_hidden_states: Optional[bool] = None,
1182
+ return_dict: Optional[bool] = None,
1183
+ labels: Optional[torch.LongTensor] = None,
1184
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1185
+ r"""
1186
+ labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
1187
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1188
+ config.num_labels - 1]`. A KLDivLoss is computed between the labels and the returned logits.
1189
+
1190
+ Returns:
1191
+
1192
+ Example:
1193
+
1194
+ ```python
1195
+ # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
1196
+ from transformers import AutoTokenizer, VisualBertForQuestionAnswering
1197
+ import torch
1198
+
1199
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1200
+ model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")
1201
+
1202
+ text = "Who is eating the apple?"
1203
+ inputs = tokenizer(text, return_tensors="pt")
1204
+ visual_embeds = get_visual_embeddings(image).unsqueeze(0)
1205
+ visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
1206
+ visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
1207
+
1208
+ inputs.update(
1209
+ {
1210
+ "visual_embeds": visual_embeds,
1211
+ "visual_token_type_ids": visual_token_type_ids,
1212
+ "visual_attention_mask": visual_attention_mask,
1213
+ }
1214
+ )
1215
+
1216
+ labels = torch.tensor([[0.0, 1.0]]).unsqueeze(0) # Batch size 1, Num labels 2
1217
+
1218
+ outputs = model(**inputs, labels=labels)
1219
+ loss = outputs.loss
1220
+ scores = outputs.logits
1221
+ ```"""
1222
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1223
+
1224
+ # Get the index of the last text token
1225
+ index_to_gather = attention_mask.sum(1) - 2 # as in original code
1226
+
1227
+ outputs = self.visual_bert(
1228
+ input_ids,
1229
+ attention_mask=attention_mask,
1230
+ token_type_ids=token_type_ids,
1231
+ position_ids=position_ids,
1232
+ head_mask=head_mask,
1233
+ inputs_embeds=inputs_embeds,
1234
+ visual_embeds=visual_embeds,
1235
+ visual_attention_mask=visual_attention_mask,
1236
+ visual_token_type_ids=visual_token_type_ids,
1237
+ image_text_alignment=image_text_alignment,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ )
1242
+
1243
+ sequence_output = outputs[0]
1244
+
1245
+ # TO-CHECK: From the original code
1246
+ index_to_gather = (
1247
+ index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, sequence_output.size(-1))
1248
+ )
1249
+ pooled_output = torch.gather(sequence_output, 1, index_to_gather)
1250
+
1251
+ pooled_output = self.dropout(pooled_output)
1252
+ logits = self.cls(pooled_output)
1253
+ reshaped_logits = logits.view(-1, self.num_labels)
1254
+
1255
+ loss = None
1256
+ if labels is not None:
1257
+ loss_fct = nn.KLDivLoss(reduction="batchmean")
1258
+ log_softmax = nn.LogSoftmax(dim=-1)
1259
+ reshaped_logits = log_softmax(reshaped_logits)
1260
+ loss = loss_fct(reshaped_logits, labels.contiguous())
1261
+ if not return_dict:
1262
+ output = (reshaped_logits,) + outputs[2:]
1263
+ return ((loss,) + output) if loss is not None else output
1264
+
1265
+ return SequenceClassifierOutput(
1266
+ loss=loss,
1267
+ logits=reshaped_logits,
1268
+ hidden_states=outputs.hidden_states,
1269
+ attentions=outputs.attentions,
1270
+ )
1271
+
1272
+
1273
+ @add_start_docstrings(
1274
+ """
1275
+ VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled
1276
+ output) for Visual Reasoning e.g. for NLVR task.
1277
+ """,
1278
+ VISUAL_BERT_START_DOCSTRING,
1279
+ )
1280
+ class VisualBertForVisualReasoning(VisualBertPreTrainedModel):
1281
+ def __init__(self, config):
1282
+ super().__init__(config)
1283
+ self.num_labels = config.num_labels
1284
+
1285
+ self.visual_bert = VisualBertModel(config)
1286
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1287
+ self.cls = nn.Linear(config.hidden_size, config.num_labels) # 2
1288
+
1289
+ # Initialize weights and apply final processing
1290
+ self.post_init()
1291
+
1292
+ @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1293
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1294
+ def forward(
1295
+ self,
1296
+ input_ids: Optional[torch.LongTensor] = None,
1297
+ attention_mask: Optional[torch.LongTensor] = None,
1298
+ token_type_ids: Optional[torch.LongTensor] = None,
1299
+ position_ids: Optional[torch.LongTensor] = None,
1300
+ head_mask: Optional[torch.LongTensor] = None,
1301
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1302
+ visual_embeds: Optional[torch.FloatTensor] = None,
1303
+ visual_attention_mask: Optional[torch.LongTensor] = None,
1304
+ visual_token_type_ids: Optional[torch.LongTensor] = None,
1305
+ image_text_alignment: Optional[torch.LongTensor] = None,
1306
+ output_attentions: Optional[bool] = None,
1307
+ output_hidden_states: Optional[bool] = None,
1308
+ return_dict: Optional[bool] = None,
1309
+ labels: Optional[torch.LongTensor] = None,
1310
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1311
+ r"""
1312
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1313
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1314
+ config.num_labels - 1]`. A classification loss is computed (Cross-Entropy) against these labels.
1315
+
1316
+ Returns:
1317
+
1318
+ Example:
1319
+
1320
+ ```python
1321
+ # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
1322
+ from transformers import AutoTokenizer, VisualBertForVisualReasoning
1323
+ import torch
1324
+
1325
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1326
+ model = VisualBertForVisualReasoning.from_pretrained("uclanlp/visualbert-nlvr2")
1327
+
1328
+ text = "Who is eating the apple?"
1329
+ inputs = tokenizer(text, return_tensors="pt")
1330
+ visual_embeds = get_visual_embeddings(image).unsqueeze(0)
1331
+ visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
1332
+ visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
1333
+
1334
+ inputs.update(
1335
+ {
1336
+ "visual_embeds": visual_embeds,
1337
+ "visual_token_type_ids": visual_token_type_ids,
1338
+ "visual_attention_mask": visual_attention_mask,
1339
+ }
1340
+ )
1341
+
1342
+ labels = torch.tensor(1).unsqueeze(0) # Batch size 1, Num choices 2
1343
+
1344
+ outputs = model(**inputs, labels=labels)
1345
+ loss = outputs.loss
1346
+ scores = outputs.logits
1347
+ ```"""
1348
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1349
+
1350
+ outputs = self.visual_bert(
1351
+ input_ids,
1352
+ attention_mask=attention_mask,
1353
+ token_type_ids=token_type_ids,
1354
+ position_ids=position_ids,
1355
+ head_mask=head_mask,
1356
+ inputs_embeds=inputs_embeds,
1357
+ visual_embeds=visual_embeds,
1358
+ visual_attention_mask=visual_attention_mask,
1359
+ visual_token_type_ids=visual_token_type_ids,
1360
+ image_text_alignment=image_text_alignment,
1361
+ output_attentions=output_attentions,
1362
+ output_hidden_states=output_hidden_states,
1363
+ return_dict=return_dict,
1364
+ )
1365
+
1366
+ # sequence_output = outputs[0]
1367
+ pooled_output = outputs[1]
1368
+ pooled_output = self.dropout(pooled_output)
1369
+ logits = self.cls(pooled_output)
1370
+ reshaped_logits = logits.contiguous()
1371
+
1372
+ loss = None
1373
+ if labels is not None:
1374
+ loss_fct = CrossEntropyLoss()
1375
+ loss = loss_fct(reshaped_logits, labels.view(-1))
1376
+
1377
+ if not return_dict:
1378
+ output = (logits,) + outputs[2:]
1379
+ return ((loss,) + output) if loss is not None else output
1380
+
1381
+ return SequenceClassifierOutput(
1382
+ loss=loss,
1383
+ logits=reshaped_logits,
1384
+ hidden_states=outputs.hidden_states,
1385
+ attentions=outputs.attentions,
1386
+ )
1387
+
1388
+
1389
+ class VisualBertRegionToPhraseAttention(nn.Module):
1390
+ def __init__(self, config):
1391
+ super().__init__()
1392
+ if config.hidden_size % config.num_attention_heads != 0:
1393
+ raise ValueError(
1394
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
1395
+ f"heads ({config.num_attention_heads})"
1396
+ )
1397
+ self.num_attention_heads = 1 # config.num_attention_heads
1398
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
1399
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
1400
+
1401
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
1402
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
1403
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
1404
+
1405
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
1406
+
1407
+ def transpose_for_scores(self, x):
1408
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
1409
+ x = x.view(*new_x_shape)
1410
+ return x.permute(0, 2, 1, 3)
1411
+
1412
+ def forward(self, query, key, attention_mask):
1413
+ attention_mask = attention_mask.to(query.dtype)
1414
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1415
+ attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min
1416
+
1417
+ mixed_query_layer = self.query(query)
1418
+ mixed_key_layer = self.key(key)
1419
+
1420
+ query_layer = self.transpose_for_scores(mixed_query_layer)
1421
+ key_layer = self.transpose_for_scores(mixed_key_layer)
1422
+
1423
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
1424
+
1425
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
1426
+
1427
+ attention_scores = attention_scores + attention_mask
1428
+
1429
+ attention_scores = attention_scores.squeeze(1)
1430
+ return attention_scores
1431
+
1432
+
1433
+ @add_start_docstrings(
1434
+ """
1435
+ VisualBert Model with a Masked Language Modeling head and an attention layer on top for Region-to-Phrase Alignment
1436
+ e.g. for Flickr30 Entities task.
1437
+ """,
1438
+ VISUAL_BERT_START_DOCSTRING,
1439
+ )
1440
+ class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
1441
+ _tied_weights_keys = ["cls.predictions.decoder.bias"]
1442
+
1443
+ def __init__(self, config):
1444
+ super().__init__(config)
1445
+
1446
+ self.visual_bert = VisualBertModel(config)
1447
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1448
+ self.cls = VisualBertPreTrainingHeads(config)
1449
+ self.attention = VisualBertRegionToPhraseAttention(config)
1450
+
1451
+ # Initialize weights and apply final processing
1452
+ self.post_init()
1453
+
1454
+ @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1455
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1456
+ def forward(
1457
+ self,
1458
+ input_ids: Optional[torch.LongTensor] = None,
1459
+ attention_mask: Optional[torch.LongTensor] = None,
1460
+ token_type_ids: Optional[torch.LongTensor] = None,
1461
+ position_ids: Optional[torch.LongTensor] = None,
1462
+ head_mask: Optional[torch.LongTensor] = None,
1463
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1464
+ visual_embeds: Optional[torch.FloatTensor] = None,
1465
+ visual_attention_mask: Optional[torch.LongTensor] = None,
1466
+ visual_token_type_ids: Optional[torch.LongTensor] = None,
1467
+ image_text_alignment: Optional[torch.LongTensor] = None,
1468
+ output_attentions: Optional[bool] = None,
1469
+ output_hidden_states: Optional[bool] = None,
1470
+ return_dict: Optional[bool] = None,
1471
+ region_to_phrase_position: Optional[torch.LongTensor] = None,
1472
+ labels: Optional[torch.LongTensor] = None,
1473
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1474
+ r"""
1475
+ region_to_phrase_position (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
1476
+ The positions depicting the position of the image embedding corresponding to the textual tokens.
1477
+
1478
+ labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length, visual_sequence_length)`, *optional*):
1479
+ Labels for computing the masked language modeling loss. KLDivLoss is computed against these labels and the
1480
+ outputs from the attention layer.
1481
+
1482
+ Returns:
1483
+
1484
+ Example:
1485
+
1486
+ ```python
1487
+ # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
1488
+ from transformers import AutoTokenizer, VisualBertForRegionToPhraseAlignment
1489
+ import torch
1490
+
1491
+ tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
1492
+ model = VisualBertForRegionToPhraseAlignment.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
1493
+
1494
+ text = "Who is eating the apple?"
1495
+ inputs = tokenizer(text, return_tensors="pt")
1496
+ visual_embeds = get_visual_embeddings(image).unsqueeze(0)
1497
+ visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
1498
+ visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
1499
+ region_to_phrase_position = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]))
1500
+
1501
+ inputs.update(
1502
+ {
1503
+ "region_to_phrase_position": region_to_phrase_position,
1504
+ "visual_embeds": visual_embeds,
1505
+ "visual_token_type_ids": visual_token_type_ids,
1506
+ "visual_attention_mask": visual_attention_mask,
1507
+ }
1508
+ )
1509
+
1510
+ labels = torch.ones(
1511
+ (1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2], visual_embeds.shape[-2])
1512
+ ) # Batch size 1
1513
+
1514
+ outputs = model(**inputs, labels=labels)
1515
+ loss = outputs.loss
1516
+ scores = outputs.logits
1517
+ ```"""
1518
+ if region_to_phrase_position is None:
1519
+ raise ValueError("`region_to_phrase_position` should not be None when using Flickr Model.")
1520
+
1521
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1522
+
1523
+ outputs = self.visual_bert(
1524
+ input_ids,
1525
+ attention_mask=attention_mask,
1526
+ token_type_ids=token_type_ids,
1527
+ position_ids=position_ids,
1528
+ head_mask=head_mask,
1529
+ inputs_embeds=inputs_embeds,
1530
+ visual_embeds=visual_embeds,
1531
+ visual_attention_mask=visual_attention_mask,
1532
+ visual_token_type_ids=visual_token_type_ids,
1533
+ image_text_alignment=image_text_alignment,
1534
+ output_attentions=output_attentions,
1535
+ output_hidden_states=output_hidden_states,
1536
+ return_dict=return_dict,
1537
+ )
1538
+
1539
+ sequence_output = outputs[0]
1540
+
1541
+ region_to_phrase_position_mask = (region_to_phrase_position != -1).long()
1542
+
1543
+ # Make the -1 become 0
1544
+ region_to_phrase_position = region_to_phrase_position * region_to_phrase_position_mask
1545
+
1546
+ # Selected_positions = batch x selected position x dim
1547
+ expanded_region_to_phrase_positions = region_to_phrase_position.unsqueeze(2).expand(
1548
+ region_to_phrase_position.size(0), region_to_phrase_position.size(1), sequence_output.size(2)
1549
+ )
1550
+ selected_positions = sequence_output.gather(1, expanded_region_to_phrase_positions)
1551
+
1552
+ # Visual Features = batch x visual_feature_length x dim
1553
+ # This will need separate image and visual masks.
1554
+ visual_features = sequence_output[:, attention_mask.size(1) :]
1555
+
1556
+ if visual_features.size(1) != visual_attention_mask.size(1):
1557
+ raise ValueError(
1558
+ f"Visual features length :{visual_features.size(1)} should be the same"
1559
+ f" as visual attention mask length: {visual_attention_mask.size(1)}."
1560
+ )
1561
+
1562
+ logits = self.attention(selected_positions, visual_features, visual_attention_mask)
1563
+
1564
+ loss = None
1565
+
1566
+ if labels is not None:
1567
+ # scores = batch x selected position x visual_feature
1568
+ # scores = selected_positions.bmm(visual_features.transpose(1,2))
1569
+ # label = batch x selected_postion x needed position
1570
+ loss_fct = KLDivLoss(reduction="batchmean")
1571
+ log_softmax = LogSoftmax(dim=-1)
1572
+ scores = log_softmax(logits)
1573
+ labels = labels.contiguous()
1574
+ loss = loss_fct(scores, labels)
1575
+
1576
+ if not return_dict:
1577
+ output = (logits,) + outputs[2:]
1578
+ return ((loss,) + output) if loss is not None else output
1579
+
1580
+ return SequenceClassifierOutput(
1581
+ loss=loss,
1582
+ logits=logits,
1583
+ hidden_states=outputs.hidden_states,
1584
+ attentions=outputs.attentions,
1585
+ )
1586
+
1587
+
1588
+ __all__ = [
1589
+ "VisualBertForMultipleChoice",
1590
+ "VisualBertForPreTraining",
1591
+ "VisualBertForQuestionAnswering",
1592
+ "VisualBertForRegionToPhraseAlignment",
1593
+ "VisualBertForVisualReasoning",
1594
+ "VisualBertLayer",
1595
+ "VisualBertModel",
1596
+ "VisualBertPreTrainedModel",
1597
+ ]
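Note on the docstring examples above: every snippet assumes a `get_visual_embeddings(image)` helper, which VisualBERT does not ship. Below is a minimal, purely illustrative stand-in; the helper name, the region count of 36, and the use of random features in place of real detector outputs are all assumptions, and the feature dimension must match the checkpoint's `config.visual_embedding_dim` (512 is the config default).

```python
# Hypothetical stand-in for the get_visual_embeddings(image) helper assumed
# by the docstring examples. Real pipelines use pooled region features from
# an object detector; random features of the right shape at least make the
# examples runnable end to end.
import torch


def get_visual_embeddings(image, num_regions=36, visual_embedding_dim=512):
    # Returns (num_regions, visual_embedding_dim); visual_embedding_dim must
    # match model.config.visual_embedding_dim for the checkpoint in use.
    return torch.randn(num_regions, visual_embedding_dim)
```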
docs/transformers/build/lib/transformers/models/vit/__init__.py ADDED
@@ -0,0 +1,32 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_vit import *
+     from .feature_extraction_vit import *
+     from .image_processing_vit import *
+     from .image_processing_vit_fast import *
+     from .modeling_flax_vit import *
+     from .modeling_tf_vit import *
+     from .modeling_vit import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
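A quick usage note on the lazy `__init__` above: submodules are only imported when an attribute is first looked up. The sketch below is ordinary user code (nothing hypothetical beyond having `transformers` installed) and pays the import cost of `configuration_vit` only at the first access.

```python
# The attribute lookup below is resolved through _LazyModule, which imports
# configuration_vit on first access rather than at package import time.
from transformers.models.vit import ViTConfig

config = ViTConfig()
print(type(config).__name__)  # ViTConfig
```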
docs/transformers/build/lib/transformers/models/vit/configuration_vit.py ADDED
@@ -0,0 +1,151 @@
+ # coding=utf-8
+ # Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ViT model configuration"""
+
+ from collections import OrderedDict
+ from typing import Mapping
+
+ from packaging import version
+
+ from ...configuration_utils import PretrainedConfig
+ from ...onnx import OnnxConfig
+ from ...utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class ViTConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`ViTModel`]. It is used to instantiate a ViT
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+     defaults will yield a similar configuration to that of the ViT
+     [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         intermediate_size (`int`, *optional*, defaults to 3072):
+             Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+         hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+             The epsilon used by the layer normalization layers.
+         image_size (`int`, *optional*, defaults to 224):
+             The size (resolution) of each image.
+         patch_size (`int`, *optional*, defaults to 16):
+             The size (resolution) of each patch.
+         num_channels (`int`, *optional*, defaults to 3):
+             The number of input channels.
+         qkv_bias (`bool`, *optional*, defaults to `True`):
+             Whether to add a bias to the queries, keys and values.
+         encoder_stride (`int`, *optional*, defaults to 16):
+             Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+         pooler_output_size (`int`, *optional*):
+             Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+         pooler_act (`str`, *optional*, defaults to `"tanh"`):
+             The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+             PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+             supported for TensorFlow.
+
+     Example:
+
+     ```python
+     >>> from transformers import ViTConfig, ViTModel
+
+     >>> # Initializing a ViT vit-base-patch16-224 style configuration
+     >>> configuration = ViTConfig()
+
+     >>> # Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
+     >>> model = ViTModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "vit"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.0,
+         attention_probs_dropout_prob=0.0,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         image_size=224,
+         patch_size=16,
+         num_channels=3,
+         qkv_bias=True,
+         encoder_stride=16,
+         pooler_output_size=None,
+         pooler_act="tanh",
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.image_size = image_size
+         self.patch_size = patch_size
+         self.num_channels = num_channels
+         self.qkv_bias = qkv_bias
+         self.encoder_stride = encoder_stride
+         self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size
+         self.pooler_act = pooler_act
+
+
+ class ViTOnnxConfig(OnnxConfig):
+     torch_onnx_minimum_version = version.parse("1.11")
+
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         return OrderedDict(
+             [
+                 ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+             ]
+         )
+
+     @property
+     def atol_for_validation(self) -> float:
+         return 1e-4
+
+
+ __all__ = ["ViTConfig", "ViTOnnxConfig"]
docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py ADDED
@@ -0,0 +1,218 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert ViT checkpoints trained with the DINO method."""
+
+ import argparse
+ import json
+ from pathlib import Path
+
+ import requests
+ import torch
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+ from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
+ from transformers.utils import logging
+
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger(__name__)
+
+
+ # here we list all keys to be renamed (original name on the left, our name on the right)
+ def create_rename_keys(config, base_model=False):
+     rename_keys = []
+     for i in range(config.num_hidden_layers):
+         # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
+         rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
+         rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
+         rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
+         rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
+         rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
+         rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
+         rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
+         rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
+         rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
+         rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
+
+     # projection layer + position embeddings
+     rename_keys.extend(
+         [
+             ("cls_token", "vit.embeddings.cls_token"),
+             ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
+             ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
+             ("pos_embed", "vit.embeddings.position_embeddings"),
+         ]
+     )
+
+     if base_model:
+         # layernorm + pooler
+         rename_keys.extend(
+             [
+                 ("norm.weight", "layernorm.weight"),
+                 ("norm.bias", "layernorm.bias"),
+             ]
+         )
+
+         # if just the base model, we should remove "vit" from all keys that start with "vit"
+         rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
+     else:
+         # layernorm + classification head
+         rename_keys.extend(
+             [
+                 ("norm.weight", "vit.layernorm.weight"),
+                 ("norm.bias", "vit.layernorm.bias"),
+                 ("head.weight", "classifier.weight"),
+                 ("head.bias", "classifier.bias"),
+             ]
+         )
+
+     return rename_keys
+
+
+ # we split up the matrix of each encoder layer into queries, keys and values
+ def read_in_q_k_v(state_dict, config, base_model=False):
+     for i in range(config.num_hidden_layers):
+         if base_model:
+             prefix = ""
+         else:
+             prefix = "vit."
+         # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+         in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+         in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+         # next, add query, keys and values (in that order) to the state dict
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
+             : config.hidden_size, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+             config.hidden_size : config.hidden_size * 2, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+             config.hidden_size : config.hidden_size * 2
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
+             -config.hidden_size :, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+ def remove_classification_head_(state_dict):
+     ignore_keys = ["head.weight", "head.bias"]
+     for k in ignore_keys:
+         state_dict.pop(k, None)
+
+
+ def rename_key(dct, old, new):
+     val = dct.pop(old)
+     dct[new] = val
+
+
+ # We will verify our results on an image of cute cats
+ def prepare_img():
+     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+     im = Image.open(requests.get(url, stream=True).raw)
+     return im
+
+
+ @torch.no_grad()
+ def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
+     """
+     Copy/paste/tweak model's weights to our ViT structure.
+     """
+
+     # define default ViT configuration
+     config = ViTConfig()
+     # patch_size
+     if model_name[-1] == "8":
+         config.patch_size = 8
+     # set labels if required
+     if not base_model:
+         config.num_labels = 1000
+         repo_id = "huggingface/label-files"
+         filename = "imagenet-1k-id2label.json"
+         id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
+         id2label = {int(k): v for k, v in id2label.items()}
+         config.id2label = id2label
+         config.label2id = {v: k for k, v in id2label.items()}
+     # size of the architecture
+     if model_name in ["dino_vits8", "dino_vits16"]:
+         config.hidden_size = 384
+         config.intermediate_size = 1536
+         config.num_hidden_layers = 12
+         config.num_attention_heads = 6
+
+     # load original model from torch hub
+     original_model = torch.hub.load("facebookresearch/dino:main", model_name)
+     original_model.eval()
+
+     # load state_dict of original model, remove and rename some keys
+     state_dict = original_model.state_dict()
+     if base_model:
+         remove_classification_head_(state_dict)
+     rename_keys = create_rename_keys(config, base_model=base_model)
+     for src, dest in rename_keys:
+         rename_key(state_dict, src, dest)
+     read_in_q_k_v(state_dict, config, base_model)
+
+     # load HuggingFace model
+     if base_model:
+         model = ViTModel(config, add_pooling_layer=False).eval()
+     else:
+         model = ViTForImageClassification(config).eval()
+     model.load_state_dict(state_dict)
+
+     # Check outputs on an image, prepared by ViTImageProcessor
+     image_processor = ViTImageProcessor()
+     encoding = image_processor(images=prepare_img(), return_tensors="pt")
+     pixel_values = encoding["pixel_values"]
+     outputs = model(pixel_values)
+
+     if base_model:
+         final_hidden_state_cls_token = original_model(pixel_values)
+         assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
+     else:
+         logits = original_model(pixel_values)
+         assert logits.shape == outputs.logits.shape
+         assert torch.allclose(logits, outputs.logits, atol=1e-3)
+
+     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+     print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
+     model.save_pretrained(pytorch_dump_folder_path)
+     print(f"Saving image processor to {pytorch_dump_folder_path}")
+     image_processor.save_pretrained(pytorch_dump_folder_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--model_name",
+         default="dino_vitb16",
+         type=str,
+         help="Name of the model trained with DINO you'd like to convert.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+     )
+     parser.add_argument(
+         "--base_model",
+         action="store_true",
+         help="Whether to only convert the base model (no projection head weights).",
+     )
+
+     parser.set_defaults(base_model=True)
+     args = parser.parse_args()
+     convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
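The least obvious step in this converter is `read_in_q_k_v`, which slices timm/DINO's fused qkv projection into separate query/key/value weights. A self-contained toy check of the slicing order it relies on (all shapes below are made-up small values, not real model sizes):

```python
# Toy verification of the qkv split used by read_in_q_k_v: the fused
# projection is one (3 * hidden, hidden) matrix whose hidden-sized row
# blocks are query, key and value, in that order.
import torch

hidden = 4
qkv = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)
q, k, v = qkv[:hidden], qkv[hidden : 2 * hidden], qkv[-hidden:]
assert torch.equal(torch.cat([q, k, v]), qkv)
```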
docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py ADDED
@@ -0,0 +1,254 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert ViT and non-distilled DeiT checkpoints from the timm library."""
+
+ import argparse
+ from pathlib import Path
+
+ import requests
+ import timm
+ import torch
+ from PIL import Image
+ from timm.data import ImageNetInfo, infer_imagenet_subset
+
+ from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel
+ from transformers.utils import logging
+
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger(__name__)
+
+
+ # here we list all keys to be renamed (original name on the left, our name on the right)
+ def create_rename_keys(config, base_model=False):
+     rename_keys = []
+     for i in range(config.num_hidden_layers):
+         # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
+         rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
+         rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
+         rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
+         rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
+         rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
+         rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
+         rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
+         rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
+         rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
+         rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
+
+     # projection layer + position embeddings
+     rename_keys.extend(
+         [
+             ("cls_token", "vit.embeddings.cls_token"),
+             ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
+             ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
+             ("pos_embed", "vit.embeddings.position_embeddings"),
+         ]
+     )
+
+     if base_model:
+         # layernorm
+         rename_keys.extend(
+             [
+                 ("norm.weight", "layernorm.weight"),
+                 ("norm.bias", "layernorm.bias"),
+             ]
+         )
+
+         # if just the base model, we should remove "vit" from all keys that start with "vit"
+         rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
+     else:
+         # layernorm + classification head
+         rename_keys.extend(
+             [
+                 ("norm.weight", "vit.layernorm.weight"),
+                 ("norm.bias", "vit.layernorm.bias"),
+                 ("head.weight", "classifier.weight"),
+                 ("head.bias", "classifier.bias"),
+             ]
+         )
+
+     return rename_keys
+
+
+ # we split up the matrix of each encoder layer into queries, keys and values
+ def read_in_q_k_v(state_dict, config, base_model=False):
+     for i in range(config.num_hidden_layers):
+         if base_model:
+             prefix = ""
+         else:
+             prefix = "vit."
+         # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+         in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+         in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+         # next, add query, keys and values (in that order) to the state dict
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
+             : config.hidden_size, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+             config.hidden_size : config.hidden_size * 2, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+             config.hidden_size : config.hidden_size * 2
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
+             -config.hidden_size :, :
+         ]
+         state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+ def remove_classification_head_(state_dict):
+     ignore_keys = ["head.weight", "head.bias"]
+     for k in ignore_keys:
+         state_dict.pop(k, None)
+
+
+ def rename_key(dct, old, new):
+     val = dct.pop(old)
+     dct[new] = val
+
+
+ # We will verify our results on an image of cute cats
+ def prepare_img():
+     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+     im = Image.open(requests.get(url, stream=True).raw)
+     return im
+
+
+ @torch.no_grad()
+ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
+     """
+     Copy/paste/tweak model's weights to our ViT structure.
+     """
+
+     # define default ViT configuration
+     config = ViTConfig()
+     base_model = False
+
+     # load original model from timm
+     timm_model = timm.create_model(vit_name, pretrained=True)
+     timm_model.eval()
+
+     # detect unsupported ViT models in transformers
+     # fc_norm is present
+     if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity):
+         raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.")
+
+     # use of global average pooling in combination (or without) class token
+     if getattr(timm_model, "global_pool", None) == "avg":
+         raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.")
+
+     # CLIP style vit with norm_pre layer present
+     if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity):
+         raise ValueError(
+             f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer."
+         )
+
+     # SigLIP style vit with attn_pool layer present
+     if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map":
+         raise ValueError(
+             f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool."
+         )
+
+     # use of layer scale in ViT model blocks
+     if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance(
+         getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity
+     ):
+         raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.")
+
+     # Hybrid ResNet-ViTs
+     if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed):
+         raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.")
+
+     # get patch size and image size from the patch embedding submodule
+     config.patch_size = timm_model.patch_embed.patch_size[0]
+     config.image_size = timm_model.patch_embed.img_size[0]
+
+     # retrieve architecture-specific parameters from the timm model
+     config.hidden_size = timm_model.embed_dim
+     config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features
+     config.num_hidden_layers = len(timm_model.blocks)
+     config.num_attention_heads = timm_model.blocks[0].attn.num_heads
+
+     # check whether the model has a classification head or not
+     if timm_model.num_classes != 0:
+         config.num_labels = timm_model.num_classes
+         # infer ImageNet subset from timm model
+         imagenet_subset = infer_imagenet_subset(timm_model)
+         dataset_info = ImageNetInfo(imagenet_subset)
+         config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())}
+         config.label2id = {v: k for k, v in config.id2label.items()}
+     else:
+         print(f"{vit_name} is going to be converted as a feature extractor only.")
+         base_model = True
+
+     # load state_dict of original model
+     state_dict = timm_model.state_dict()
+
+     # remove and rename some keys in the state dict
+     if base_model:
+         remove_classification_head_(state_dict)
+     rename_keys = create_rename_keys(config, base_model)
+     for src, dest in rename_keys:
+         rename_key(state_dict, src, dest)
+     read_in_q_k_v(state_dict, config, base_model)
+
+     # load HuggingFace model
+     if base_model:
+         model = ViTModel(config, add_pooling_layer=False).eval()
+     else:
+         model = ViTForImageClassification(config).eval()
+     model.load_state_dict(state_dict)
+
+     # Check outputs on an image, prepared by ViTImageProcessor/DeiTImageProcessor
+     if "deit" in vit_name:
+         image_processor = DeiTImageProcessor(size=config.image_size)
+     else:
+         image_processor = ViTImageProcessor(size=config.image_size)
+     encoding = image_processor(images=prepare_img(), return_tensors="pt")
+     pixel_values = encoding["pixel_values"]
+     outputs = model(pixel_values)
+
+     if base_model:
+         timm_pooled_output = timm_model.forward_features(pixel_values)
+         assert timm_pooled_output.shape == outputs.last_hidden_state.shape
+         assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1)
+     else:
+         timm_logits = timm_model(pixel_values)
+         assert timm_logits.shape == outputs.logits.shape
+         assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)
+
+     Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+     print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
+     model.save_pretrained(pytorch_dump_folder_path)
+     print(f"Saving image processor to {pytorch_dump_folder_path}")
+     image_processor.save_pretrained(pytorch_dump_folder_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--vit_name",
+         default="vit_base_patch16_224",
+         type=str,
+         help="Name of the ViT timm model you'd like to convert.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+     )
+
+     args = parser.parse_args()
+     convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
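For reference, a hedged sketch of driving this converter programmatically instead of through argparse. The import assumes the script is on the Python path under its file name, and the model name and output folder are illustrative; the call downloads the pretrained timm checkpoint.

```python
# Illustrative programmatic invocation of the converter above; requires the
# timm package and network access to fetch the pretrained weights.
from convert_vit_timm_to_pytorch import convert_vit_checkpoint

convert_vit_checkpoint("vit_base_patch16_224", "./vit-base-patch16-224")
```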
docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py ADDED
@@ -0,0 +1,38 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Feature extractor class for ViT."""
+
+ import warnings
+
+ from ...utils import logging
+ from ...utils.import_utils import requires
+ from .image_processing_vit import ViTImageProcessor
+
+
+ logger = logging.get_logger(__name__)
+
+
+ @requires(backends=("vision",))
+ class ViTFeatureExtractor(ViTImageProcessor):
+     def __init__(self, *args, **kwargs) -> None:
+         warnings.warn(
+             "The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please"
+             " use ViTImageProcessor instead.",
+             FutureWarning,
+         )
+         super().__init__(*args, **kwargs)
+
+
+ __all__ = ["ViTFeatureExtractor"]
docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py ADDED
@@ -0,0 +1,288 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for ViT."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
22
+ from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
23
+ from ...image_utils import (
24
+ IMAGENET_STANDARD_MEAN,
25
+ IMAGENET_STANDARD_STD,
26
+ ChannelDimension,
27
+ ImageInput,
28
+ PILImageResampling,
29
+ infer_channel_dimension_format,
30
+ is_scaled_image,
31
+ make_list_of_images,
32
+ to_numpy_array,
33
+ valid_images,
34
+ validate_preprocess_arguments,
35
+ )
36
+ from ...utils import TensorType, filter_out_non_signature_kwargs, logging
37
+ from ...utils.import_utils import requires
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+
43
+ @requires(backends=("vision",))
44
+ class ViTImageProcessor(BaseImageProcessor):
45
+ r"""
46
+ Constructs a ViT image processor.
47
+
48
+ Args:
49
+ do_resize (`bool`, *optional*, defaults to `True`):
50
+ Whether to resize the image's (height, width) dimensions to the specified `(size["height"],
51
+ size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method.
52
+ size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
53
+ Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
54
+ method.
55
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
56
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
57
+ `preprocess` method.
58
+ do_rescale (`bool`, *optional*, defaults to `True`):
59
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
60
+ parameter in the `preprocess` method.
61
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
62
+ Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
63
+ `preprocess` method.
64
+ do_normalize (`bool`, *optional*, defaults to `True`):
65
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
66
+ method.
67
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
68
+ Mean to use if normalizing the image. This is a float or a list of floats whose length matches the number
69
+ of channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
70
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
71
+ Standard deviation to use if normalizing the image. This is a float or a list of floats whose length matches
72
+ the number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
73
+ do_convert_rgb (`bool`, *optional*):
74
+ Whether to convert the image to RGB.
75
+ """
76
+
77
+ model_input_names = ["pixel_values"]
78
+
79
+ def __init__(
80
+ self,
81
+ do_resize: bool = True,
82
+ size: Optional[Dict[str, int]] = None,
83
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
84
+ do_rescale: bool = True,
85
+ rescale_factor: Union[int, float] = 1 / 255,
86
+ do_normalize: bool = True,
87
+ image_mean: Optional[Union[float, List[float]]] = None,
88
+ image_std: Optional[Union[float, List[float]]] = None,
89
+ do_convert_rgb: Optional[bool] = None,
90
+ **kwargs,
91
+ ) -> None:
92
+ super().__init__(**kwargs)
93
+ size = size if size is not None else {"height": 224, "width": 224}
94
+ size = get_size_dict(size)
95
+ self.do_resize = do_resize
96
+ self.do_rescale = do_rescale
97
+ self.do_normalize = do_normalize
98
+ self.size = size
99
+ self.resample = resample
100
+ self.rescale_factor = rescale_factor
101
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
102
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
103
+ self.do_convert_rgb = do_convert_rgb
104
+
105
+ def resize(
106
+ self,
107
+ image: np.ndarray,
108
+ size: Dict[str, int],
109
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
110
+ data_format: Optional[Union[str, ChannelDimension]] = None,
111
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
112
+ **kwargs,
113
+ ) -> np.ndarray:
114
+ """
115
+ Resize an image to `(size["height"], size["width"])`.
116
+
117
+ Args:
118
+ image (`np.ndarray`):
119
+ Image to resize.
120
+ size (`Dict[str, int]`):
121
+ Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
122
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
123
+ `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
124
+ data_format (`ChannelDimension` or `str`, *optional*):
125
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
126
+ image is used. Can be one of:
127
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
128
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
129
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
130
+ input_data_format (`ChannelDimension` or `str`, *optional*):
131
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
132
+ from the input image. Can be one of:
133
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
134
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
135
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
136
+
137
+ Returns:
138
+ `np.ndarray`: The resized image.
139
+ """
140
+ size = get_size_dict(size)
141
+ if "height" not in size or "width" not in size:
142
+ raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
143
+ output_size = (size["height"], size["width"])
144
+ return resize(
145
+ image,
146
+ size=output_size,
147
+ resample=resample,
148
+ data_format=data_format,
149
+ input_data_format=input_data_format,
150
+ **kwargs,
151
+ )
152
+
153
+ @filter_out_non_signature_kwargs()
154
+ def preprocess(
155
+ self,
156
+ images: ImageInput,
157
+ do_resize: Optional[bool] = None,
158
+ size: Dict[str, int] = None,
159
+ resample: PILImageResampling = None,
160
+ do_rescale: Optional[bool] = None,
161
+ rescale_factor: Optional[float] = None,
162
+ do_normalize: Optional[bool] = None,
163
+ image_mean: Optional[Union[float, List[float]]] = None,
164
+ image_std: Optional[Union[float, List[float]]] = None,
165
+ return_tensors: Optional[Union[str, TensorType]] = None,
166
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
167
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
168
+ do_convert_rgb: Optional[bool] = None,
169
+ ):
170
+ """
171
+ Preprocess an image or batch of images.
172
+
173
+ Args:
174
+ images (`ImageInput`):
175
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
176
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
177
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
178
+ Whether to resize the image.
179
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
180
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
181
+ resizing.
182
+ resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`):
183
+ `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has
184
+ an effect if `do_resize` is set to `True`.
185
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
186
+ Whether to rescale the image values to the range [0, 1].
187
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
188
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
189
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
190
+ Whether to normalize the image.
191
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
192
+ Image mean to use if `do_normalize` is set to `True`.
193
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
194
+ Image standard deviation to use if `do_normalize` is set to `True`.
195
+ return_tensors (`str` or `TensorType`, *optional*):
196
+ The type of tensors to return. Can be one of:
197
+ - Unset: Return a list of `np.ndarray`.
198
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
199
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
200
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
201
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
202
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
203
+ The channel dimension format for the output image. Can be one of:
204
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
205
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
206
+ - Unset: Use the channel dimension format of the input image.
207
+ input_data_format (`ChannelDimension` or `str`, *optional*):
208
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
209
+ from the input image. Can be one of:
210
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
211
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
212
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
213
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
214
+ Whether to convert the image to RGB.
215
+ """
216
+ do_resize = do_resize if do_resize is not None else self.do_resize
217
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
218
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
219
+ resample = resample if resample is not None else self.resample
220
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
221
+ image_mean = image_mean if image_mean is not None else self.image_mean
222
+ image_std = image_std if image_std is not None else self.image_std
223
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
224
+
225
+ size = size if size is not None else self.size
226
+ size_dict = get_size_dict(size)
227
+
228
+ images = make_list_of_images(images)
229
+
230
+ if not valid_images(images):
231
+ raise ValueError(
232
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
233
+ "torch.Tensor, tf.Tensor or jax.ndarray."
234
+ )
235
+ validate_preprocess_arguments(
236
+ do_rescale=do_rescale,
237
+ rescale_factor=rescale_factor,
238
+ do_normalize=do_normalize,
239
+ image_mean=image_mean,
240
+ image_std=image_std,
241
+ do_resize=do_resize,
242
+ size=size,
243
+ resample=resample,
244
+ )
245
+
246
+ if do_convert_rgb:
247
+ images = [convert_to_rgb(image) for image in images]
248
+
249
+ # All transformations expect numpy arrays.
250
+ images = [to_numpy_array(image) for image in images]
251
+
252
+ if do_rescale and is_scaled_image(images[0]):
253
+ logger.warning_once(
254
+ "It looks like you are trying to rescale already rescaled images. If the input"
255
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
256
+ )
257
+
258
+ if input_data_format is None:
259
+ # We assume that all images have the same channel dimension format.
260
+ input_data_format = infer_channel_dimension_format(images[0])
261
+
262
+ if do_resize:
263
+ images = [
264
+ self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format)
265
+ for image in images
266
+ ]
267
+
268
+ if do_rescale:
269
+ images = [
270
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
271
+ for image in images
272
+ ]
273
+
274
+ if do_normalize:
275
+ images = [
276
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
277
+ for image in images
278
+ ]
279
+
280
+ images = [
281
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
282
+ ]
283
+
284
+ data = {"pixel_values": images}
285
+ return BatchFeature(data=data, tensor_type=return_tensors)
286
+
287
+
288
+ __all__ = ["ViTImageProcessor"]
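
A short usage sketch of the `preprocess` pipeline above (resize -> rescale -> normalize); the random array merely stands in for a real photo, and the printed shape follows from the default `size` and `data_format=ChannelDimension.FIRST`:

```python
# Hedged sketch: end-to-end preprocessing with the defaults defined above.
import numpy as np

from transformers import ViTImageProcessor

image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # HWC, values in [0, 255]
processor = ViTImageProcessor()  # 224x224, bilinear, 1/255 rescale, ImageNet standard mean/std of 0.5

batch = processor(images=image, return_tensors="np")
print(batch["pixel_values"].shape)  # (1, 3, 224, 224): channels-first batch
print(batch["pixel_values"].mean())  # roughly 0.0 after (x - 0.5) / 0.5 normalization
```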
docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py ADDED
@@ -0,0 +1,45 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fast Image processor class for ViT."""
16
+
17
+ from ...image_processing_utils_fast import (
18
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
19
+ BaseImageProcessorFast,
20
+ )
21
+ from ...image_utils import (
22
+ IMAGENET_STANDARD_MEAN,
23
+ IMAGENET_STANDARD_STD,
24
+ PILImageResampling,
25
+ )
26
+ from ...utils import (
27
+ add_start_docstrings,
28
+ )
29
+
30
+
31
+ @add_start_docstrings(
32
+ "Constructs a fast ViT image processor.",
33
+ BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
34
+ )
35
+ class ViTImageProcessorFast(BaseImageProcessorFast):
36
+ resample = PILImageResampling.BILINEAR
37
+ image_mean = IMAGENET_STANDARD_MEAN
38
+ image_std = IMAGENET_STANDARD_STD
39
+ size = {"height": 224, "width": 224}
40
+ do_resize = True
41
+ do_rescale = True
42
+ do_normalize = True
43
+
44
+
45
+ __all__ = ["ViTImageProcessorFast"]
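
A matching sketch for the fast variant; it shares the class-level defaults declared above but is torch-backed, so this assumes PyTorch (and torchvision) are installed and returns `torch.Tensor` batches:

```python
# Hedged sketch: same defaults as ViTImageProcessor, but on the torch-based fast path.
import numpy as np

from transformers import ViTImageProcessorFast

image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
processor = ViTImageProcessorFast()

batch = processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])
```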
docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py ADDED
@@ -0,0 +1,677 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
22
+ from flax.linen.attention import dot_product_attention_weights
23
+ from flax.traverse_util import flatten_dict, unflatten_dict
24
+
25
+ from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
26
+ from ...modeling_flax_utils import (
27
+ ACT2FN,
28
+ FlaxPreTrainedModel,
29
+ append_replace_return_docstrings,
30
+ overwrite_call_docstring,
31
+ )
32
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
33
+ from .configuration_vit import ViTConfig
34
+
35
+
36
+ VIT_START_DOCSTRING = r"""
37
+
38
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
39
+ library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
40
+
41
+ This model is also a
42
+ [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
43
+ a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
44
+ behavior.
45
+
46
+ Finally, this model supports inherent JAX features such as:
47
+
48
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
49
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
50
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
51
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
52
+
53
+ Parameters:
54
+ config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
55
+ Initializing with a config file does not load the weights associated with the model, only the
56
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
57
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
58
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
59
+ `jax.numpy.bfloat16` (on TPUs).
60
+
61
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
62
+ specified, all the computation will be performed with the given `dtype`.
63
+
64
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
65
+ parameters.**
66
+
67
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
68
+ [`~FlaxPreTrainedModel.to_bf16`].
69
+ """
70
+
71
+ VIT_INPUTS_DOCSTRING = r"""
72
+ Args:
73
+ pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
74
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
75
+ for details.
76
+
77
+ output_attentions (`bool`, *optional*):
78
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
79
+ tensors for more detail.
80
+ output_hidden_states (`bool`, *optional*):
81
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
82
+ more detail.
83
+ return_dict (`bool`, *optional*):
84
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
85
+ """
86
+
87
+
88
+ class FlaxViTPatchEmbeddings(nn.Module):
89
+ config: ViTConfig
90
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
91
+
92
+ def setup(self):
93
+ image_size = self.config.image_size
94
+ patch_size = self.config.patch_size
95
+ num_patches = (image_size // patch_size) * (image_size // patch_size)
96
+ self.num_patches = num_patches
97
+ self.num_channels = self.config.num_channels
98
+ self.projection = nn.Conv(
99
+ self.config.hidden_size,
100
+ kernel_size=(patch_size, patch_size),
101
+ strides=(patch_size, patch_size),
102
+ padding="VALID",
103
+ dtype=self.dtype,
104
+ kernel_init=jax.nn.initializers.variance_scaling(
105
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
106
+ ),
107
+ )
108
+
109
+ def __call__(self, pixel_values):
110
+ num_channels = pixel_values.shape[-1]
111
+ if num_channels != self.num_channels:
112
+ raise ValueError(
113
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
114
+ )
115
+ embeddings = self.projection(pixel_values)
116
+ batch_size, _, _, channels = embeddings.shape
117
+ return jnp.reshape(embeddings, (batch_size, -1, channels))
118
+
119
+
120
+ class FlaxViTEmbeddings(nn.Module):
121
+ """Construct the CLS token, position and patch embeddings."""
122
+
123
+ config: ViTConfig
124
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
125
+
126
+ def setup(self):
127
+ self.cls_token = self.param(
128
+ "cls_token",
129
+ jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
130
+ (1, 1, self.config.hidden_size),
131
+ )
132
+ self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
133
+ num_patches = self.patch_embeddings.num_patches
134
+ self.position_embeddings = self.param(
135
+ "position_embeddings",
136
+ jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
137
+ (1, num_patches + 1, self.config.hidden_size),
138
+ )
139
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
140
+
141
+ def __call__(self, pixel_values, deterministic=True):
142
+ batch_size = pixel_values.shape[0]
143
+
144
+ embeddings = self.patch_embeddings(pixel_values)
145
+
146
+ cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
147
+ embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
148
+ embeddings = embeddings + self.position_embeddings
149
+ embeddings = self.dropout(embeddings, deterministic=deterministic)
150
+ return embeddings
151
+
152
+
153
+ class FlaxViTSelfAttention(nn.Module):
154
+ config: ViTConfig
155
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
156
+
157
+ def setup(self):
158
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
159
+ raise ValueError(
160
+ f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
161
+ f" {self.config.num_attention_heads}"
162
+ )
163
+
164
+ self.query = nn.Dense(
165
+ self.config.hidden_size,
166
+ dtype=self.dtype,
167
+ kernel_init=jax.nn.initializers.variance_scaling(
168
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
169
+ ),
170
+ use_bias=self.config.qkv_bias,
171
+ )
172
+ self.key = nn.Dense(
173
+ self.config.hidden_size,
174
+ dtype=self.dtype,
175
+ kernel_init=jax.nn.initializers.variance_scaling(
176
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
177
+ ),
178
+ use_bias=self.config.qkv_bias,
179
+ )
180
+ self.value = nn.Dense(
181
+ self.config.hidden_size,
182
+ dtype=self.dtype,
183
+ kernel_init=jax.nn.initializers.variance_scaling(
184
+ self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
185
+ ),
186
+ use_bias=self.config.qkv_bias,
187
+ )
188
+
189
+ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
190
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
191
+
192
+ query_states = self.query(hidden_states).reshape(
193
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
194
+ )
195
+ value_states = self.value(hidden_states).reshape(
196
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
197
+ )
198
+ key_states = self.key(hidden_states).reshape(
199
+ hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
200
+ )
201
+
202
+ dropout_rng = None
203
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
204
+ dropout_rng = self.make_rng("dropout")
205
+
206
+ attn_weights = dot_product_attention_weights(
207
+ query_states,
208
+ key_states,
209
+ dropout_rng=dropout_rng,
210
+ dropout_rate=self.config.attention_probs_dropout_prob,
211
+ broadcast_dropout=True,
212
+ deterministic=deterministic,
213
+ dtype=self.dtype,
214
+ precision=None,
215
+ )
216
+
217
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
218
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
219
+
220
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
221
+ return outputs
222
+
223
+
224
+ class FlaxViTSelfOutput(nn.Module):
225
+ config: ViTConfig
226
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
227
+
228
+ def setup(self):
229
+ self.dense = nn.Dense(
230
+ self.config.hidden_size,
231
+ kernel_init=jax.nn.initializers.variance_scaling(
232
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
233
+ ),
234
+ dtype=self.dtype,
235
+ )
236
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
237
+
238
+ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
239
+ hidden_states = self.dense(hidden_states)
240
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
241
+ return hidden_states
242
+
243
+
244
+ class FlaxViTAttention(nn.Module):
245
+ config: ViTConfig
246
+ dtype: jnp.dtype = jnp.float32
247
+
248
+ def setup(self):
249
+ self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
250
+ self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)
251
+
252
+ def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
253
+ attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
254
+ attn_output = attn_outputs[0]
255
+ hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
256
+
257
+ outputs = (hidden_states,)
258
+
259
+ if output_attentions:
260
+ outputs += (attn_outputs[1],)
261
+
262
+ return outputs
263
+
264
+
265
+ class FlaxViTIntermediate(nn.Module):
266
+ config: ViTConfig
267
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
268
+
269
+ def setup(self):
270
+ self.dense = nn.Dense(
271
+ self.config.intermediate_size,
272
+ kernel_init=jax.nn.initializers.variance_scaling(
273
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
274
+ ),
275
+ dtype=self.dtype,
276
+ )
277
+ self.activation = ACT2FN[self.config.hidden_act]
278
+
279
+ def __call__(self, hidden_states):
280
+ hidden_states = self.dense(hidden_states)
281
+ hidden_states = self.activation(hidden_states)
282
+ return hidden_states
283
+
284
+
285
+ class FlaxViTOutput(nn.Module):
286
+ config: ViTConfig
287
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
288
+
289
+ def setup(self):
290
+ self.dense = nn.Dense(
291
+ self.config.hidden_size,
292
+ kernel_init=jax.nn.initializers.variance_scaling(
293
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
294
+ ),
295
+ dtype=self.dtype,
296
+ )
297
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
298
+
299
+ def __call__(self, hidden_states, attention_output, deterministic: bool = True):
300
+ hidden_states = self.dense(hidden_states)
301
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
302
+ hidden_states = hidden_states + attention_output
303
+ return hidden_states
304
+
305
+
306
+ class FlaxViTLayer(nn.Module):
307
+ config: ViTConfig
308
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
309
+
310
+ def setup(self):
311
+ self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
312
+ self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
313
+ self.output = FlaxViTOutput(self.config, dtype=self.dtype)
314
+ self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
315
+ self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
316
+
317
+ def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
318
+ attention_outputs = self.attention(
319
+ self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
320
+ deterministic=deterministic,
321
+ output_attentions=output_attentions,
322
+ )
323
+
324
+ attention_output = attention_outputs[0]
325
+
326
+ # first residual connection
327
+ attention_output = attention_output + hidden_states
328
+
329
+ # in ViT, layernorm is also applied after self-attention
330
+ layer_output = self.layernorm_after(attention_output)
331
+
332
+ hidden_states = self.intermediate(layer_output)
333
+ hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
334
+
335
+ outputs = (hidden_states,)
336
+
337
+ if output_attentions:
338
+ outputs += (attention_outputs[1],)
339
+ return outputs
340
+
341
+
342
+ class FlaxViTLayerCollection(nn.Module):
343
+ config: ViTConfig
344
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
345
+
346
+ def setup(self):
347
+ self.layers = [
348
+ FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
349
+ ]
350
+
351
+ def __call__(
352
+ self,
353
+ hidden_states,
354
+ deterministic: bool = True,
355
+ output_attentions: bool = False,
356
+ output_hidden_states: bool = False,
357
+ return_dict: bool = True,
358
+ ):
359
+ all_attentions = () if output_attentions else None
360
+ all_hidden_states = () if output_hidden_states else None
361
+
362
+ for i, layer in enumerate(self.layers):
363
+ if output_hidden_states:
364
+ all_hidden_states += (hidden_states,)
365
+
366
+ layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
367
+
368
+ hidden_states = layer_outputs[0]
369
+
370
+ if output_attentions:
371
+ all_attentions += (layer_outputs[1],)
372
+
373
+ if output_hidden_states:
374
+ all_hidden_states += (hidden_states,)
375
+
376
+ outputs = (hidden_states,)
377
+ if not return_dict:
378
+ return tuple(v for v in outputs if v is not None)
379
+
380
+ return FlaxBaseModelOutput(
381
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
382
+ )
383
+
384
+
385
+ class FlaxViTEncoder(nn.Module):
386
+ config: ViTConfig
387
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
388
+
389
+ def setup(self):
390
+ self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)
391
+
392
+ def __call__(
393
+ self,
394
+ hidden_states,
395
+ deterministic: bool = True,
396
+ output_attentions: bool = False,
397
+ output_hidden_states: bool = False,
398
+ return_dict: bool = True,
399
+ ):
400
+ return self.layer(
401
+ hidden_states,
402
+ deterministic=deterministic,
403
+ output_attentions=output_attentions,
404
+ output_hidden_states=output_hidden_states,
405
+ return_dict=return_dict,
406
+ )
407
+
408
+
409
+ class FlaxViTPooler(nn.Module):
410
+ config: ViTConfig
411
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
412
+
413
+ def setup(self):
414
+ self.dense = nn.Dense(
415
+ self.config.pooler_output_size,
416
+ kernel_init=jax.nn.initializers.variance_scaling(
417
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
418
+ ),
419
+ dtype=self.dtype,
420
+ )
421
+ self.activation = ACT2FN[self.config.pooler_act]
422
+
423
+ def __call__(self, hidden_states):
424
+ cls_hidden_state = hidden_states[:, 0]
425
+ cls_hidden_state = self.dense(cls_hidden_state)
426
+ return self.activation(cls_hidden_state)
427
+
428
+
429
+ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
430
+ """
431
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
432
+ models.
433
+ """
434
+
435
+ config_class = ViTConfig
436
+ base_model_prefix = "vit"
437
+ main_input_name = "pixel_values"
438
+ module_class: nn.Module = None
439
+
440
+ def __init__(
441
+ self,
442
+ config: ViTConfig,
443
+ input_shape=None,
444
+ seed: int = 0,
445
+ dtype: jnp.dtype = jnp.float32,
446
+ _do_init: bool = True,
447
+ **kwargs,
448
+ ):
449
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
450
+ if input_shape is None:
451
+ input_shape = (1, config.image_size, config.image_size, config.num_channels)
452
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
453
+
454
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
455
+ # init input tensors
456
+ pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
457
+
458
+ params_rng, dropout_rng = jax.random.split(rng)
459
+ rngs = {"params": params_rng, "dropout": dropout_rng}
460
+
461
+ random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]
462
+
463
+ if params is not None:
464
+ random_params = flatten_dict(unfreeze(random_params))
465
+ params = flatten_dict(unfreeze(params))
466
+ for missing_key in self._missing_keys:
467
+ params[missing_key] = random_params[missing_key]
468
+ self._missing_keys = set()
469
+ return freeze(unflatten_dict(params))
470
+ else:
471
+ return random_params
472
+
473
+ @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
474
+ def __call__(
475
+ self,
476
+ pixel_values,
477
+ params: dict = None,
478
+ dropout_rng: jax.random.PRNGKey = None,
479
+ train: bool = False,
480
+ output_attentions: Optional[bool] = None,
481
+ output_hidden_states: Optional[bool] = None,
482
+ return_dict: Optional[bool] = None,
483
+ ):
484
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
485
+ output_hidden_states = (
486
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
487
+ )
488
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
489
+
490
+ pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
491
+ # Handle any PRNG if needed
492
+ rngs = {}
493
+ if dropout_rng is not None:
494
+ rngs["dropout"] = dropout_rng
495
+
496
+ return self.module.apply(
497
+ {"params": params or self.params},
498
+ jnp.array(pixel_values, dtype=jnp.float32),
499
+ not train,
500
+ output_attentions,
501
+ output_hidden_states,
502
+ return_dict,
503
+ rngs=rngs,
504
+ )
505
+
506
+
507
+ class FlaxViTModule(nn.Module):
508
+ config: ViTConfig
509
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
510
+ add_pooling_layer: bool = True
511
+
512
+ def setup(self):
513
+ self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
514
+ self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
515
+ self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
516
+ self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None
517
+
518
+ def __call__(
519
+ self,
520
+ pixel_values,
521
+ deterministic: bool = True,
522
+ output_attentions: bool = False,
523
+ output_hidden_states: bool = False,
524
+ return_dict: bool = True,
525
+ ):
526
+ hidden_states = self.embeddings(pixel_values, deterministic=deterministic)
527
+
528
+ outputs = self.encoder(
529
+ hidden_states,
530
+ deterministic=deterministic,
531
+ output_attentions=output_attentions,
532
+ output_hidden_states=output_hidden_states,
533
+ return_dict=return_dict,
534
+ )
535
+ hidden_states = outputs[0]
536
+ hidden_states = self.layernorm(hidden_states)
537
+ pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
538
+
539
+ if not return_dict:
540
+ # if pooled is None, don't return it
541
+ if pooled is None:
542
+ return (hidden_states,) + outputs[1:]
543
+ return (hidden_states, pooled) + outputs[1:]
544
+
545
+ return FlaxBaseModelOutputWithPooling(
546
+ last_hidden_state=hidden_states,
547
+ pooler_output=pooled,
548
+ hidden_states=outputs.hidden_states,
549
+ attentions=outputs.attentions,
550
+ )
551
+
552
+
553
+ @add_start_docstrings(
554
+ "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
555
+ VIT_START_DOCSTRING,
556
+ )
557
+ class FlaxViTModel(FlaxViTPreTrainedModel):
558
+ module_class = FlaxViTModule
559
+
560
+
561
+ FLAX_VISION_MODEL_DOCSTRING = """
562
+ Returns:
563
+
564
+ Examples:
565
+
566
+ ```python
567
+ >>> from transformers import AutoImageProcessor, FlaxViTModel
568
+ >>> from PIL import Image
569
+ >>> import requests
570
+
571
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
572
+ >>> image = Image.open(requests.get(url, stream=True).raw)
573
+
574
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
575
+ >>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
576
+
577
+ >>> inputs = image_processor(images=image, return_tensors="np")
578
+ >>> outputs = model(**inputs)
579
+ >>> last_hidden_states = outputs.last_hidden_state
580
+ ```
581
+ """
582
+
583
+ overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
584
+ append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)
585
+
586
+
587
+ class FlaxViTForImageClassificationModule(nn.Module):
588
+ config: ViTConfig
589
+ dtype: jnp.dtype = jnp.float32
590
+
591
+ def setup(self):
592
+ self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
593
+ self.classifier = nn.Dense(
594
+ self.config.num_labels,
595
+ dtype=self.dtype,
596
+ kernel_init=jax.nn.initializers.variance_scaling(
597
+ self.config.initializer_range**2, "fan_in", "truncated_normal"
598
+ ),
599
+ )
600
+
601
+ def __call__(
602
+ self,
603
+ pixel_values=None,
604
+ deterministic: bool = True,
605
+ output_attentions=None,
606
+ output_hidden_states=None,
607
+ return_dict=None,
608
+ ):
609
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
610
+
611
+ outputs = self.vit(
612
+ pixel_values,
613
+ deterministic=deterministic,
614
+ output_attentions=output_attentions,
615
+ output_hidden_states=output_hidden_states,
616
+ return_dict=return_dict,
617
+ )
618
+
619
+ hidden_states = outputs[0]
620
+ logits = self.classifier(hidden_states[:, 0, :])
621
+
622
+ if not return_dict:
623
+ output = (logits,) + outputs[2:]
624
+ return output
625
+
626
+ return FlaxSequenceClassifierOutput(
627
+ logits=logits,
628
+ hidden_states=outputs.hidden_states,
629
+ attentions=outputs.attentions,
630
+ )
631
+
632
+
633
+ @add_start_docstrings(
634
+ """
635
+ ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
636
+ the [CLS] token) e.g. for ImageNet.
637
+ """,
638
+ VIT_START_DOCSTRING,
639
+ )
640
+ class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
641
+ module_class = FlaxViTForImageClassificationModule
642
+
643
+
644
+ FLAX_VISION_CLASSIF_DOCSTRING = """
645
+ Returns:
646
+
647
+ Example:
648
+
649
+ ```python
650
+ >>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
651
+ >>> from PIL import Image
652
+ >>> import jax
653
+ >>> import requests
654
+
655
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
656
+ >>> image = Image.open(requests.get(url, stream=True).raw)
657
+
658
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
659
+ >>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
660
+
661
+ >>> inputs = image_processor(images=image, return_tensors="np")
662
+ >>> outputs = model(**inputs)
663
+ >>> logits = outputs.logits
664
+
665
+ >>> # model predicts one of the 1000 ImageNet classes
666
+ >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
667
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
668
+ ```
669
+ """
670
+
671
+ overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
672
+ append_replace_return_docstrings(
673
+ FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
674
+ )
675
+
676
+
677
+ __all__ = ["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]
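
Building on the `dtype` note in `VIT_START_DOCSTRING`, a hedged sketch of bfloat16 inference: `dtype` only switches the computation, so the parameters are cast separately with `to_bf16`, as the docstring prescribes (checkpoint name taken from the examples above):

```python
# Hedged sketch: half-precision inference with the Flax model defined above.
import jax.numpy as jnp

from transformers import FlaxViTModel

model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k", dtype=jnp.bfloat16)
model.params = model.to_bf16(model.params)  # cast weights too; `dtype` alone only sets the computation

pixel_values = jnp.zeros((1, 3, 224, 224))  # NCHW; transposed to NHWC inside __call__
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 197, 768) for the base patch16 model
```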
docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py ADDED
@@ -0,0 +1,907 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """TF 2.0 ViT model."""
16
+
17
+ from __future__ import annotations
18
+
19
+ import collections.abc
20
+ import math
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import tensorflow as tf
25
+
26
+ from ...activations_tf import get_tf_activation
27
+ from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
28
+ from ...modeling_tf_utils import (
29
+ TFModelInputType,
30
+ TFPreTrainedModel,
31
+ TFSequenceClassificationLoss,
32
+ get_initializer,
33
+ keras,
34
+ keras_serializable,
35
+ unpack_inputs,
36
+ )
37
+ from ...tf_utils import shape_list, stable_softmax
38
+ from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
39
+ from .configuration_vit import ViTConfig
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ # General docstring
45
+ _CONFIG_FOR_DOC = "ViTConfig"
46
+
47
+ # Base docstring
48
+ _CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
49
+ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
50
+
51
+ # Image classification docstring
52
+ _IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
53
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
54
+
55
+
56
+ class TFViTEmbeddings(keras.layers.Layer):
57
+ """
58
+ Construct the CLS token, position and patch embeddings.
59
+
60
+ """
61
+
62
+ def __init__(self, config: ViTConfig, **kwargs):
63
+ super().__init__(**kwargs)
64
+
65
+ self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
66
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
67
+ self.config = config
68
+
69
+ def build(self, input_shape=None):
70
+ num_patches = self.patch_embeddings.num_patches
71
+ self.cls_token = self.add_weight(
72
+ shape=(1, 1, self.config.hidden_size),
73
+ initializer=get_initializer(self.config.initializer_range),
74
+ trainable=True,
75
+ name="cls_token",
76
+ )
77
+ self.position_embeddings = self.add_weight(
78
+ shape=(1, num_patches + 1, self.config.hidden_size),
79
+ initializer=get_initializer(self.config.initializer_range),
80
+ trainable=True,
81
+ name="position_embeddings",
82
+ )
83
+
84
+ if self.built:
85
+ return
86
+ self.built = True
87
+ if getattr(self, "patch_embeddings", None) is not None:
88
+ with tf.name_scope(self.patch_embeddings.name):
89
+ self.patch_embeddings.build(None)
90
+
91
+ def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
92
+ """
93
+ This method allows interpolating the pre-trained position encodings so that the model can be used on
94
+ higher-resolution images.
95
+
96
+ Source:
97
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
98
+ """
99
+
100
+ batch_size, seq_len, dim = shape_list(embeddings)
101
+ num_patches = seq_len - 1
102
+
103
+ _, num_positions, _ = shape_list(self.position_embeddings)
104
+ num_positions -= 1
105
+
106
+ if num_patches == num_positions and height == width:
107
+ return self.position_embeddings
108
+ class_pos_embed = self.position_embeddings[:, :1]
109
+ patch_pos_embed = self.position_embeddings[:, 1:]
110
+ h0 = height // self.config.patch_size
111
+ w0 = width // self.config.patch_size
112
+ patch_pos_embed = tf.image.resize(
113
+ images=tf.reshape(
114
+ patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
115
+ ),
116
+ size=(h0, w0),
117
+ method="bicubic",
118
+ )
119
+
120
+ shape = shape_list(patch_pos_embed)
121
+ assert h0 == shape[-3] and w0 == shape[-2]
122
+ patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
123
+ return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
124
+
125
+ def call(
126
+ self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
127
+ ) -> tf.Tensor:
128
+ batch_size, num_channels, height, width = shape_list(pixel_values)
129
+ embeddings = self.patch_embeddings(
130
+ pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
131
+ )
132
+
133
+ # add the [CLS] token to the embedded patch tokens
134
+ cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
135
+ embeddings = tf.concat((cls_tokens, embeddings), axis=1)
136
+
137
+ # add positional encoding to each token
138
+ if interpolate_pos_encoding:
139
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
140
+ else:
141
+ embeddings = embeddings + self.position_embeddings
142
+
143
+ embeddings = self.dropout(embeddings, training=training)
144
+
145
+ return embeddings
146
+
147
+
148
+ # Based on timm implementation, which can be found here:
149
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
150
+ class TFViTPatchEmbeddings(keras.layers.Layer):
151
+ """
152
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
153
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
154
+ Transformer.
155
+ """
156
+
157
+ def __init__(self, config: ViTConfig, **kwargs):
158
+ super().__init__(**kwargs)
159
+ image_size, patch_size = config.image_size, config.patch_size
160
+ num_channels, hidden_size = config.num_channels, config.hidden_size
161
+
162
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
163
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
164
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
165
+ self.image_size = image_size
166
+ self.patch_size = patch_size
167
+ self.num_patches = num_patches
168
+ self.num_channels = num_channels
169
+ self.config = config
170
+
171
+ self.projection = keras.layers.Conv2D(
172
+ filters=hidden_size,
173
+ kernel_size=patch_size,
174
+ strides=patch_size,
175
+ padding="valid",
176
+ data_format="channels_last",
177
+ use_bias=True,
178
+ kernel_initializer=get_initializer(self.config.initializer_range),
179
+ bias_initializer="zeros",
180
+ name="projection",
181
+ )
182
+
183
+ def call(
184
+ self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
185
+ ) -> tf.Tensor:
186
+ batch_size, num_channels, height, width = shape_list(pixel_values)
187
+ if tf.executing_eagerly() and num_channels != self.num_channels:
188
+ raise ValueError(
189
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
190
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
191
+ if not interpolate_pos_encoding:
192
+ if tf.executing_eagerly():
193
+ if height != self.image_size[0] or width != self.image_size[1]:
194
+ raise ValueError(
195
+ f"Input image size ({height}*{width}) doesn't match model"
196
+ f" ({self.image_size[0]}*{self.image_size[1]})."
197
+ )
198
+
199
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
200
+ # So change the input format from `NCHW` to `NHWC`.
201
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
202
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
203
+
204
+ projection = self.projection(pixel_values)
205
+
206
+ # Change the 2D spatial dimensions to a single temporal dimension.
207
+ # shape = (batch_size, num_patches, out_channels=embed_dim)
208
+ num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
209
+ embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
210
+
211
+ return embeddings
212
+
213
+ def build(self, input_shape=None):
214
+ if self.built:
215
+ return
216
+ self.built = True
217
+ if getattr(self, "projection", None) is not None:
218
+ with tf.name_scope(self.projection.name):
219
+ self.projection.build([None, None, None, self.num_channels])
220
+
221
+
222
+ class TFViTSelfAttention(keras.layers.Layer):
223
+ def __init__(self, config: ViTConfig, **kwargs):
224
+ super().__init__(**kwargs)
225
+
226
+ if config.hidden_size % config.num_attention_heads != 0:
227
+ raise ValueError(
228
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number "
229
+ f"of attention heads ({config.num_attention_heads})"
230
+ )
231
+
232
+ self.num_attention_heads = config.num_attention_heads
233
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
234
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
235
+ self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
236
+
237
+ self.query = keras.layers.Dense(
238
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
239
+ )
240
+ self.key = keras.layers.Dense(
241
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
242
+ )
243
+ self.value = keras.layers.Dense(
244
+ units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
245
+ )
246
+ self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
247
+ self.config = config
248
+
249
+ def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
250
+ # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
251
+ tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
252
+
253
+ # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
254
+ return tf.transpose(tensor, perm=[0, 2, 1, 3])
255
+
256
+ def call(
257
+ self,
258
+ hidden_states: tf.Tensor,
259
+ head_mask: tf.Tensor,
260
+ output_attentions: bool,
261
+ training: bool = False,
262
+ ) -> Tuple[tf.Tensor]:
263
+ batch_size = shape_list(hidden_states)[0]
264
+ mixed_query_layer = self.query(inputs=hidden_states)
265
+ mixed_key_layer = self.key(inputs=hidden_states)
266
+ mixed_value_layer = self.value(inputs=hidden_states)
267
+ query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
268
+ key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
269
+ value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
270
+
271
+ # Take the dot product between "query" and "key" to get the raw attention scores.
272
+ # (batch size, num_heads, seq_len_q, seq_len_k)
273
+ attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
274
+ dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
275
+ attention_scores = tf.divide(attention_scores, dk)
276
+
277
+ # Normalize the attention scores to probabilities.
278
+ attention_probs = stable_softmax(logits=attention_scores, axis=-1)
279
+
280
+ # This is actually dropping out entire tokens to attend to, which might
281
+ # seem a bit unusual, but is taken from the original Transformer paper.
282
+ attention_probs = self.dropout(inputs=attention_probs, training=training)
283
+
284
+ # Mask heads if we want to
285
+ if head_mask is not None:
286
+ attention_probs = tf.multiply(attention_probs, head_mask)
287
+
288
+ attention_output = tf.matmul(attention_probs, value_layer)
289
+ attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
290
+
291
+ # (batch_size, seq_len_q, all_head_size)
292
+ attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
293
+ outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
294
+
295
+ return outputs
296
+
297
+ def build(self, input_shape=None):
298
+ if self.built:
299
+ return
300
+ self.built = True
301
+ if getattr(self, "query", None) is not None:
302
+ with tf.name_scope(self.query.name):
303
+ self.query.build([None, None, self.config.hidden_size])
304
+ if getattr(self, "key", None) is not None:
305
+ with tf.name_scope(self.key.name):
306
+ self.key.build([None, None, self.config.hidden_size])
307
+ if getattr(self, "value", None) is not None:
308
+ with tf.name_scope(self.value.name):
309
+ self.value.build([None, None, self.config.hidden_size])
310
+
311
+
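+ # A minimal shape sketch of `transpose_for_scores` above (illustrative only, not part of the
+ # library): with hidden_size=768 and num_attention_heads=12, each head has size 64, so
+ #
+ #     x = tf.random.normal((2, 197, 768))           # [batch, seq, all_head_size]
+ #     x = tf.reshape(x, shape=(2, -1, 12, 64))      # [batch, seq, heads, head_size]
+ #     x = tf.transpose(x, perm=[0, 2, 1, 3])        # [batch, heads, seq, head_size]
+ #
+ # and the subsequent query/key matmul yields [batch, heads, seq, seq] attention scores.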
312
+ class TFViTSelfOutput(keras.layers.Layer):
313
+ """
314
+ The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
315
+ layernorm applied before each block.
316
+ """
317
+
318
+ def __init__(self, config: ViTConfig, **kwargs):
319
+ super().__init__(**kwargs)
320
+
321
+ self.dense = keras.layers.Dense(
322
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
323
+ )
324
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
325
+ self.config = config
326
+
327
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
328
+ hidden_states = self.dense(inputs=hidden_states)
329
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
330
+
331
+ return hidden_states
332
+
333
+ def build(self, input_shape=None):
334
+ if self.built:
335
+ return
336
+ self.built = True
337
+ if getattr(self, "dense", None) is not None:
338
+ with tf.name_scope(self.dense.name):
339
+ self.dense.build([None, None, self.config.hidden_size])
340
+
341
+
342
+ class TFViTAttention(keras.layers.Layer):
343
+ def __init__(self, config: ViTConfig, **kwargs):
344
+ super().__init__(**kwargs)
345
+
346
+ self.self_attention = TFViTSelfAttention(config, name="attention")
347
+ self.dense_output = TFViTSelfOutput(config, name="output")
348
+
349
+ def prune_heads(self, heads):
350
+ raise NotImplementedError
351
+
352
+ def call(
353
+ self,
354
+ input_tensor: tf.Tensor,
355
+ head_mask: tf.Tensor,
356
+ output_attentions: bool,
357
+ training: bool = False,
358
+ ) -> Tuple[tf.Tensor]:
359
+ self_outputs = self.self_attention(
360
+ hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
361
+ )
362
+ attention_output = self.dense_output(
363
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
364
+ )
365
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
366
+
367
+ return outputs
368
+
369
+ def build(self, input_shape=None):
370
+ if self.built:
371
+ return
372
+ self.built = True
373
+ if getattr(self, "self_attention", None) is not None:
374
+ with tf.name_scope(self.self_attention.name):
375
+ self.self_attention.build(None)
376
+ if getattr(self, "dense_output", None) is not None:
377
+ with tf.name_scope(self.dense_output.name):
378
+ self.dense_output.build(None)
379
+
380
+
381
+ class TFViTIntermediate(keras.layers.Layer):
382
+ def __init__(self, config: ViTConfig, **kwargs):
383
+ super().__init__(**kwargs)
384
+
385
+ self.dense = keras.layers.Dense(
386
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
387
+ )
388
+
389
+ if isinstance(config.hidden_act, str):
390
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
391
+ else:
392
+ self.intermediate_act_fn = config.hidden_act
393
+ self.config = config
394
+
395
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
396
+ hidden_states = self.dense(inputs=hidden_states)
397
+ hidden_states = self.intermediate_act_fn(hidden_states)
398
+
399
+ return hidden_states
400
+
401
+ def build(self, input_shape=None):
402
+ if self.built:
403
+ return
404
+ self.built = True
405
+ if getattr(self, "dense", None) is not None:
406
+ with tf.name_scope(self.dense.name):
407
+ self.dense.build([None, None, self.config.hidden_size])
408
+
409
+
410
+ class TFViTOutput(keras.layers.Layer):
411
+ def __init__(self, config: ViTConfig, **kwargs):
412
+ super().__init__(**kwargs)
413
+
414
+ self.dense = keras.layers.Dense(
415
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
416
+ )
417
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
418
+ self.config = config
419
+
420
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
421
+ hidden_states = self.dense(inputs=hidden_states)
422
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
423
+ hidden_states = hidden_states + input_tensor
424
+
425
+ return hidden_states
426
+
427
+ def build(self, input_shape=None):
428
+ if self.built:
429
+ return
430
+ self.built = True
431
+ if getattr(self, "dense", None) is not None:
432
+ with tf.name_scope(self.dense.name):
433
+ self.dense.build([None, None, self.config.intermediate_size])
434
+
435
+
436
+ class TFViTLayer(keras.layers.Layer):
437
+ """This corresponds to the Block class in the timm implementation."""
438
+
439
+ def __init__(self, config: ViTConfig, **kwargs):
440
+ super().__init__(**kwargs)
441
+
442
+ self.attention = TFViTAttention(config, name="attention")
443
+ self.intermediate = TFViTIntermediate(config, name="intermediate")
444
+ self.vit_output = TFViTOutput(config, name="output")
445
+
446
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
447
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
448
+ self.config = config
449
+
450
+ def call(
451
+ self,
452
+ hidden_states: tf.Tensor,
453
+ head_mask: tf.Tensor,
454
+ output_attentions: bool,
455
+ training: bool = False,
456
+ ) -> Tuple[tf.Tensor]:
457
+ attention_outputs = self.attention(
458
+ # in ViT, layernorm is applied before self-attention
459
+ input_tensor=self.layernorm_before(inputs=hidden_states),
460
+ head_mask=head_mask,
461
+ output_attentions=output_attentions,
462
+ training=training,
463
+ )
464
+ attention_output = attention_outputs[0]
465
+
466
+ # first residual connection
467
+ hidden_states = attention_output + hidden_states
468
+
469
+ # in ViT, layernorm is also applied after self-attention
470
+ layer_output = self.layernorm_after(inputs=hidden_states)
471
+
472
+ intermediate_output = self.intermediate(hidden_states=layer_output)
473
+
474
+ # second residual connection is done here
475
+ layer_output = self.vit_output(
476
+ hidden_states=intermediate_output, input_tensor=hidden_states, training=training
477
+ )
478
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
479
+
480
+ return outputs
481
+
482
+ def build(self, input_shape=None):
483
+ if self.built:
484
+ return
485
+ self.built = True
486
+ if getattr(self, "attention", None) is not None:
487
+ with tf.name_scope(self.attention.name):
488
+ self.attention.build(None)
489
+ if getattr(self, "intermediate", None) is not None:
490
+ with tf.name_scope(self.intermediate.name):
491
+ self.intermediate.build(None)
492
+ if getattr(self, "vit_output", None) is not None:
493
+ with tf.name_scope(self.vit_output.name):
494
+ self.vit_output.build(None)
495
+ if getattr(self, "layernorm_before", None) is not None:
496
+ with tf.name_scope(self.layernorm_before.name):
497
+ self.layernorm_before.build([None, None, self.config.hidden_size])
498
+ if getattr(self, "layernorm_after", None) is not None:
499
+ with tf.name_scope(self.layernorm_after.name):
500
+ self.layernorm_after.build([None, None, self.config.hidden_size])
501
+
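+ # The block above implements the pre-norm residual pattern used by ViT; schematically
+ # (a sketch, not library code):
+ #
+ #     x = x + Attention(LayerNorm(x))   # first residual branch
+ #     x = x + MLP(LayerNorm(x))         # second residual branch (intermediate + output)
+ #
+ # which differs from the post-norm ordering of the original BERT/Transformer encoder.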
502
+
503
+ class TFViTEncoder(keras.layers.Layer):
504
+ def __init__(self, config: ViTConfig, **kwargs):
505
+ super().__init__(**kwargs)
506
+
507
+ self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
508
+
509
+ def call(
510
+ self,
511
+ hidden_states: tf.Tensor,
512
+ head_mask: tf.Tensor,
513
+ output_attentions: bool,
514
+ output_hidden_states: bool,
515
+ return_dict: bool,
516
+ training: bool = False,
517
+ ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
518
+ all_hidden_states = () if output_hidden_states else None
519
+ all_attentions = () if output_attentions else None
520
+
521
+ for i, layer_module in enumerate(self.layer):
522
+ if output_hidden_states:
523
+ all_hidden_states = all_hidden_states + (hidden_states,)
524
+
525
+ layer_outputs = layer_module(
526
+ hidden_states=hidden_states,
527
+ head_mask=head_mask[i],
528
+ output_attentions=output_attentions,
529
+ training=training,
530
+ )
531
+ hidden_states = layer_outputs[0]
532
+
533
+ if output_attentions:
534
+ all_attentions = all_attentions + (layer_outputs[1],)
535
+
536
+ # Add last layer
537
+ if output_hidden_states:
538
+ all_hidden_states = all_hidden_states + (hidden_states,)
539
+
540
+ if not return_dict:
541
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
542
+
543
+ return TFBaseModelOutput(
544
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
545
+ )
546
+
547
+ def build(self, input_shape=None):
548
+ if self.built:
549
+ return
550
+ self.built = True
551
+ if getattr(self, "layer", None) is not None:
552
+ for layer in self.layer:
553
+ with tf.name_scope(layer.name):
554
+ layer.build(None)
555
+
556
+
557
+ @keras_serializable
558
+ class TFViTMainLayer(keras.layers.Layer):
559
+ config_class = ViTConfig
560
+
561
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
562
+ super().__init__(**kwargs)
563
+
564
+ self.config = config
565
+
566
+ self.embeddings = TFViTEmbeddings(config, name="embeddings")
567
+ self.encoder = TFViTEncoder(config, name="encoder")
568
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
569
+ self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None
570
+
571
+ def get_input_embeddings(self) -> keras.layers.Layer:
572
+ return self.embeddings.patch_embeddings
573
+
574
+ def _prune_heads(self, heads_to_prune):
575
+ """
576
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
577
+ base class `PreTrainedModel`.
578
+ """
579
+ raise NotImplementedError
580
+
581
+ @unpack_inputs
582
+ def call(
583
+ self,
584
+ pixel_values: TFModelInputType | None = None,
585
+ head_mask: np.ndarray | tf.Tensor | None = None,
586
+ output_attentions: Optional[bool] = None,
587
+ output_hidden_states: Optional[bool] = None,
588
+ interpolate_pos_encoding: Optional[bool] = None,
589
+ return_dict: Optional[bool] = None,
590
+ training: bool = False,
591
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
592
+ if pixel_values is None:
593
+ raise ValueError("You have to specify pixel_values")
594
+
595
+ embedding_output = self.embeddings(
596
+ pixel_values=pixel_values,
597
+ interpolate_pos_encoding=interpolate_pos_encoding,
598
+ training=training,
599
+ )
600
+
601
+ # Prepare head mask if needed
602
+ # 1.0 in head_mask indicate we keep the head
603
+ # attention_probs has shape bsz x n_heads x N x N
604
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
605
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
606
+ if head_mask is not None:
607
+ raise NotImplementedError
608
+ else:
609
+ head_mask = [None] * self.config.num_hidden_layers
610
+
611
+ encoder_outputs = self.encoder(
612
+ hidden_states=embedding_output,
613
+ head_mask=head_mask,
614
+ output_attentions=output_attentions,
615
+ output_hidden_states=output_hidden_states,
616
+ return_dict=return_dict,
617
+ training=training,
618
+ )
619
+
620
+ sequence_output = encoder_outputs[0]
621
+ sequence_output = self.layernorm(inputs=sequence_output)
622
+ pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
623
+
624
+ if not return_dict:
625
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
626
+
627
+ return TFBaseModelOutputWithPooling(
628
+ last_hidden_state=sequence_output,
629
+ pooler_output=pooled_output,
630
+ hidden_states=encoder_outputs.hidden_states,
631
+ attentions=encoder_outputs.attentions,
632
+ )
633
+
634
+ def build(self, input_shape=None):
635
+ if self.built:
636
+ return
637
+ self.built = True
638
+ if getattr(self, "embeddings", None) is not None:
639
+ with tf.name_scope(self.embeddings.name):
640
+ self.embeddings.build(None)
641
+ if getattr(self, "encoder", None) is not None:
642
+ with tf.name_scope(self.encoder.name):
643
+ self.encoder.build(None)
644
+ if getattr(self, "layernorm", None) is not None:
645
+ with tf.name_scope(self.layernorm.name):
646
+ self.layernorm.build([None, None, self.config.hidden_size])
647
+ if getattr(self, "pooler", None) is not None:
648
+ with tf.name_scope(self.pooler.name):
649
+ self.pooler.build(None)
650
+
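+ # Hedged usage sketch for the main layer (normally reached through `TFViTModel` rather than
+ # used directly; shapes assume the default ViT-Base configuration):
+ #
+ #     pixel_values = tf.random.uniform((1, 3, 224, 224))
+ #     outputs = TFViTMainLayer(ViTConfig(), name="vit")(pixel_values=pixel_values)
+ #     outputs.last_hidden_state.shape  # (1, 197, 768): 196 patches + 1 [CLS] token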
651
+
652
+ class TFViTPreTrainedModel(TFPreTrainedModel):
653
+ """
654
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
655
+ models.
656
+ """
657
+
658
+ config_class = ViTConfig
659
+ base_model_prefix = "vit"
660
+ main_input_name = "pixel_values"
661
+
662
+
663
+ VIT_START_DOCSTRING = r"""
664
+
665
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
666
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
667
+ etc.)
668
+
669
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
670
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
671
+ behavior.
672
+
673
+ <Tip>
674
+
675
+ TensorFlow models and layers in `transformers` accept two formats as input:
676
+
677
+ - having all inputs as keyword arguments (like PyTorch models), or
678
+ - having all inputs as a list, tuple or dict in the first positional argument.
679
+
680
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
681
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
682
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
683
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
684
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
685
+ positional argument:
686
+
687
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
688
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
689
+ `model([pixel_values])`
690
+ - a dictionary with one or several input Tensors associated with the input names given in the docstring:
691
+ `model({"pixel_values": pixel_values})`
692
+
693
+ Note that when creating models and layers with
694
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
695
+ about any of this, as you can just pass inputs like you would to any other Python function!
696
+
697
+ </Tip>
698
+
699
+ Args:
700
+ config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
701
+ Initializing with a config file does not load the weights associated with the model, only the
702
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
703
+ """
704
+
705
+ VIT_INPUTS_DOCSTRING = r"""
706
+ Args:
707
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
708
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
709
+ for details.
710
+
711
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
712
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
713
+
714
+ - 1 indicates the head is **not masked**,
715
+ - 0 indicates the head is **masked**.
716
+
717
+ output_attentions (`bool`, *optional*):
718
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
719
+ tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
720
+ config will be used instead.
721
+ output_hidden_states (`bool`, *optional*):
722
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
723
+ more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
724
+ used instead.
725
+ interpolate_pos_encoding (`bool`, *optional*):
726
+ Whether to interpolate the pre-trained position encodings.
727
+ return_dict (`bool`, *optional*):
728
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
729
+ eager mode; in graph mode the value will always be set to `True`.
730
+ training (`bool`, *optional*, defaults to `False`):
731
+ Whether or not to use the model in training mode (some modules like dropout modules have different
732
+ behaviors between training and evaluation).
733
+ """
734
+
735
+
736
+ @add_start_docstrings(
737
+ "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
738
+ VIT_START_DOCSTRING,
739
+ )
740
+ class TFViTModel(TFViTPreTrainedModel):
741
+ def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
742
+ super().__init__(config, *inputs, **kwargs)
743
+
744
+ self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
745
+
746
+ @unpack_inputs
747
+ @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
748
+ @add_code_sample_docstrings(
749
+ checkpoint=_CHECKPOINT_FOR_DOC,
750
+ output_type=TFBaseModelOutputWithPooling,
751
+ config_class=_CONFIG_FOR_DOC,
752
+ modality="vision",
753
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
754
+ )
755
+ def call(
756
+ self,
757
+ pixel_values: TFModelInputType | None = None,
758
+ head_mask: np.ndarray | tf.Tensor | None = None,
759
+ output_attentions: Optional[bool] = None,
760
+ output_hidden_states: Optional[bool] = None,
761
+ interpolate_pos_encoding: Optional[bool] = None,
762
+ return_dict: Optional[bool] = None,
763
+ training: bool = False,
764
+ ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
765
+ outputs = self.vit(
766
+ pixel_values=pixel_values,
767
+ head_mask=head_mask,
768
+ output_attentions=output_attentions,
769
+ output_hidden_states=output_hidden_states,
770
+ interpolate_pos_encoding=interpolate_pos_encoding,
771
+ return_dict=return_dict,
772
+ training=training,
773
+ )
774
+
775
+ return outputs
776
+
777
+ def build(self, input_shape=None):
778
+ if self.built:
779
+ return
780
+ self.built = True
781
+ if getattr(self, "vit", None) is not None:
782
+ with tf.name_scope(self.vit.name):
783
+ self.vit.build(None)
784
+
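+ # A minimal end-to-end sketch with a real checkpoint (assumes network access to download the
+ # weights; `image` is a PIL.Image):
+ #
+ #     from transformers import AutoImageProcessor, TFViTModel
+ #
+ #     processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ #     model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
+ #     outputs = model(**processor(images=image, return_tensors="tf"))
+ #     outputs.last_hidden_state.shape  # (1, 197, 768)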
785
+
786
+ class TFViTPooler(keras.layers.Layer):
787
+ def __init__(self, config: ViTConfig, **kwargs):
788
+ super().__init__(**kwargs)
789
+
790
+ self.dense = keras.layers.Dense(
791
+ units=config.pooler_output_size,
792
+ kernel_initializer=get_initializer(config.initializer_range),
793
+ activation=config.pooler_act,
794
+ name="dense",
795
+ )
796
+ self.config = config
797
+
798
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
799
+ # We "pool" the model by simply taking the hidden state corresponding
800
+ # to the first token.
801
+ first_token_tensor = hidden_states[:, 0]
802
+ pooled_output = self.dense(inputs=first_token_tensor)
803
+
804
+ return pooled_output
805
+
806
+ def build(self, input_shape=None):
807
+ if self.built:
808
+ return
809
+ self.built = True
810
+ if getattr(self, "dense", None) is not None:
811
+ with tf.name_scope(self.dense.name):
812
+ self.dense.build([None, None, self.config.hidden_size])
813
+
814
+
815
+ @add_start_docstrings(
816
+ """
817
+ ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
818
+ the [CLS] token) e.g. for ImageNet.
819
+
820
+ <Tip>
821
+
822
+ Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
823
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
824
+ position embeddings to the higher resolution.
825
+
826
+ </Tip>
827
+ """,
828
+ VIT_START_DOCSTRING,
829
+ )
830
+ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
831
+ def __init__(self, config: ViTConfig, *inputs, **kwargs):
832
+ super().__init__(config, *inputs, **kwargs)
833
+
834
+ self.num_labels = config.num_labels
835
+ self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")
836
+
837
+ # Classifier head
838
+ self.classifier = keras.layers.Dense(
839
+ units=config.num_labels,
840
+ kernel_initializer=get_initializer(config.initializer_range),
841
+ name="classifier",
842
+ )
843
+ self.config = config
844
+
845
+ @unpack_inputs
846
+ @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
847
+ @add_code_sample_docstrings(
848
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
849
+ output_type=TFSequenceClassifierOutput,
850
+ config_class=_CONFIG_FOR_DOC,
851
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
852
+ )
853
+ def call(
854
+ self,
855
+ pixel_values: TFModelInputType | None = None,
856
+ head_mask: np.ndarray | tf.Tensor | None = None,
857
+ output_attentions: Optional[bool] = None,
858
+ output_hidden_states: Optional[bool] = None,
859
+ interpolate_pos_encoding: Optional[bool] = None,
860
+ return_dict: Optional[bool] = None,
861
+ labels: np.ndarray | tf.Tensor | None = None,
862
+ training: Optional[bool] = False,
863
+ ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
864
+ r"""
865
+ labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
866
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
867
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
868
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
869
+ """
870
+
871
+ outputs = self.vit(
872
+ pixel_values=pixel_values,
873
+ head_mask=head_mask,
874
+ output_attentions=output_attentions,
875
+ output_hidden_states=output_hidden_states,
876
+ interpolate_pos_encoding=interpolate_pos_encoding,
877
+ return_dict=return_dict,
878
+ training=training,
879
+ )
880
+ sequence_output = outputs[0]
881
+ logits = self.classifier(inputs=sequence_output[:, 0, :])
882
+ loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
883
+
884
+ if not return_dict:
885
+ output = (logits,) + outputs[2:]
886
+ return ((loss,) + output) if loss is not None else output
887
+
888
+ return TFSequenceClassifierOutput(
889
+ loss=loss,
890
+ logits=logits,
891
+ hidden_states=outputs.hidden_states,
892
+ attentions=outputs.attentions,
893
+ )
894
+
895
+ def build(self, input_shape=None):
896
+ if self.built:
897
+ return
898
+ self.built = True
899
+ if getattr(self, "vit", None) is not None:
900
+ with tf.name_scope(self.vit.name):
901
+ self.vit.build(None)
902
+ if getattr(self, "classifier", None) is not None:
903
+ with tf.name_scope(self.classifier.name):
904
+ self.classifier.build([None, None, self.config.hidden_size])
905
+
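+ # A hedged classification sketch (label names come from the checkpoint's `id2label` mapping;
+ # `image` is a PIL.Image):
+ #
+ #     from transformers import AutoImageProcessor, TFViTForImageClassification
+ #
+ #     processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
+ #     model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
+ #     logits = model(**processor(images=image, return_tensors="tf")).logits
+ #     predicted = int(tf.math.argmax(logits, axis=-1)[0])
+ #     model.config.id2label[predicted]  # e.g. "Egyptian cat"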
906
+
907
+ __all__ = ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel"]
docs/transformers/build/lib/transformers/models/vit/modeling_vit.py ADDED
@@ -0,0 +1,883 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViT model."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN
27
+ from ...modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPooling,
30
+ ImageClassifierOutput,
31
+ MaskedImageModelingOutput,
32
+ )
33
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
34
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
35
+ from ...utils import (
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ torch_int,
42
+ )
43
+ from .configuration_vit import ViTConfig
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ # General docstring
49
+ _CONFIG_FOR_DOC = "ViTConfig"
50
+
51
+ # Base docstring
52
+ _CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
53
+ _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
54
+
55
+ # Image classification docstring
56
+ _IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
57
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
58
+
59
+
60
+ class ViTEmbeddings(nn.Module):
61
+ """
62
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
63
+ """
64
+
65
+ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
66
+ super().__init__()
67
+
68
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
69
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
70
+ self.patch_embeddings = ViTPatchEmbeddings(config)
71
+ num_patches = self.patch_embeddings.num_patches
72
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
73
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
74
+ self.patch_size = config.patch_size
75
+ self.config = config
76
+
77
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
78
+ """
79
+ This method allows interpolating the pre-trained position encodings so that the model can be used on
80
+ higher-resolution images. It is also adapted to support torch.jit tracing.
81
+
82
+ Adapted from:
83
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
84
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
85
+ """
86
+
87
+ num_patches = embeddings.shape[1] - 1
88
+ num_positions = self.position_embeddings.shape[1] - 1
89
+
90
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
91
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
92
+ return self.position_embeddings
93
+
94
+ class_pos_embed = self.position_embeddings[:, :1]
95
+ patch_pos_embed = self.position_embeddings[:, 1:]
96
+
97
+ dim = embeddings.shape[-1]
98
+
99
+ new_height = height // self.patch_size
100
+ new_width = width // self.patch_size
101
+
102
+ sqrt_num_positions = torch_int(num_positions**0.5)
103
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
104
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
105
+
106
+ patch_pos_embed = nn.functional.interpolate(
107
+ patch_pos_embed,
108
+ size=(new_height, new_width),
109
+ mode="bicubic",
110
+ align_corners=False,
111
+ )
112
+
113
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
114
+
115
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
116
+
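+ # Worked example of the interpolation above (a sketch, assuming patch_size=16 and 224x224
+ # pretraining): the stored grid has sqrt(196)=14 positions per side; for a 384x384 input the
+ # target grid is 384//16=24 per side, so the (1, 14, 14, dim) grid is bicubically resized to
+ # (1, 24, 24, dim) and flattened back to (1, 576, dim) before re-attaching the [CLS] position.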
117
+ def forward(
118
+ self,
119
+ pixel_values: torch.Tensor,
120
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
121
+ interpolate_pos_encoding: bool = False,
122
+ ) -> torch.Tensor:
123
+ batch_size, num_channels, height, width = pixel_values.shape
124
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
125
+
126
+ if bool_masked_pos is not None:
127
+ seq_length = embeddings.shape[1]
128
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
129
+ # replace the masked visual tokens by mask_tokens
130
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
131
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
132
+
133
+ # add the [CLS] token to the embedded patch tokens
134
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
135
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
136
+
137
+ # add positional encoding to each token
138
+ if interpolate_pos_encoding:
139
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
140
+ else:
141
+ embeddings = embeddings + self.position_embeddings
142
+
143
+ embeddings = self.dropout(embeddings)
144
+
145
+ return embeddings
146
+
147
+
148
+ class ViTPatchEmbeddings(nn.Module):
149
+ """
150
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
151
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
152
+ Transformer.
153
+ """
154
+
155
+ def __init__(self, config):
156
+ super().__init__()
157
+ image_size, patch_size = config.image_size, config.patch_size
158
+ num_channels, hidden_size = config.num_channels, config.hidden_size
159
+
160
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
161
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
162
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
163
+ self.image_size = image_size
164
+ self.patch_size = patch_size
165
+ self.num_channels = num_channels
166
+ self.num_patches = num_patches
167
+
168
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
169
+
170
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
171
+ batch_size, num_channels, height, width = pixel_values.shape
172
+ if num_channels != self.num_channels:
173
+ raise ValueError(
174
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
175
+ f" Expected {self.num_channels} but got {num_channels}."
176
+ )
177
+ if not interpolate_pos_encoding:
178
+ if height != self.image_size[0] or width != self.image_size[1]:
179
+ raise ValueError(
180
+ f"Input image size ({height}*{width}) doesn't match model"
181
+ f" ({self.image_size[0]}*{self.image_size[1]})."
182
+ )
183
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
184
+ return embeddings
185
+
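+ # A minimal shape sketch of the patchify projection above (illustrative only):
+ #
+ #     proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
+ #     x = torch.randn(1, 3, 224, 224)
+ #     proj(x).shape                             # (1, 768, 14, 14)
+ #     proj(x).flatten(2).transpose(1, 2).shape  # (1, 196, 768) == (batch, num_patches, hidden)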
186
+
187
+ def eager_attention_forward(
188
+ module: nn.Module,
189
+ query: torch.Tensor,
190
+ key: torch.Tensor,
191
+ value: torch.Tensor,
192
+ attention_mask: Optional[torch.Tensor],
193
+ scaling: float,
194
+ dropout: float = 0.0,
195
+ **kwargs,
196
+ ):
197
+ # Take the dot product between "query" and "key" to get the raw attention scores.
198
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
199
+
200
+ # Normalize the attention scores to probabilities.
201
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
202
+
203
+ # This is actually dropping out entire tokens to attend to, which might
204
+ # seem a bit unusual, but is taken from the original Transformer paper.
205
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
206
+
207
+ # Mask heads if we want to
208
+ if attention_mask is not None:
209
+ attn_weights = attn_weights * attention_mask
210
+
211
+ attn_output = torch.matmul(attn_weights, value)
212
+ attn_output = attn_output.transpose(1, 2).contiguous()
213
+
214
+ return attn_output, attn_weights
215
+
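+ # The function above is standard scaled dot-product attention; in math, with scaling
+ # 1/sqrt(head_dim):
+ #
+ #     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
+ #
+ # Note that `attention_mask` here is ViT's per-head mask, multiplied into the probabilities
+ # rather than added to the logits as in causal language models.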
216
+
217
+ class ViTSelfAttention(nn.Module):
218
+ def __init__(self, config: ViTConfig) -> None:
219
+ super().__init__()
220
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
221
+ raise ValueError(
222
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
223
+ f"heads {config.num_attention_heads}."
224
+ )
225
+
226
+ self.config = config
227
+ self.num_attention_heads = config.num_attention_heads
228
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
229
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
230
+ self.dropout_prob = config.attention_probs_dropout_prob
231
+ self.scaling = self.attention_head_size**-0.5
232
+ self.is_causal = False
233
+
234
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
235
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
236
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
237
+
238
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
239
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
240
+ x = x.view(new_x_shape)
241
+ return x.permute(0, 2, 1, 3)
242
+
243
+ def forward(
244
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
245
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
246
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
247
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
248
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
249
+
250
+ attention_interface: Callable = eager_attention_forward
251
+ if self.config._attn_implementation != "eager":
252
+ if self.config._attn_implementation == "sdpa" and output_attentions:
253
+ logger.warning_once(
254
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
255
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
256
+ )
257
+ else:
258
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
259
+
260
+ context_layer, attention_probs = attention_interface(
261
+ self,
262
+ query_layer,
263
+ key_layer,
264
+ value_layer,
265
+ head_mask,
266
+ is_causal=self.is_causal,
267
+ scaling=self.scaling,
268
+ dropout=0.0 if not self.training else self.dropout_prob,
269
+ )
270
+
271
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
272
+ context_layer = context_layer.reshape(new_context_layer_shape)
273
+
274
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
275
+
276
+ return outputs
277
+
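+ # Note on the dispatch above (a sketch of the intended behavior, not extra API): the default
+ # eager path uses `eager_attention_forward`; loading the model with another implementation
+ # swaps in the corresponding kernel from ALL_ATTENTION_FUNCTIONS, e.g.
+ #
+ #     model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", attn_implementation="sdpa")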
278
+
279
+ class ViTSelfOutput(nn.Module):
280
+ """
281
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
282
+ layernorm applied before each block.
283
+ """
284
+
285
+ def __init__(self, config: ViTConfig) -> None:
286
+ super().__init__()
287
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
288
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
289
+
290
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
291
+ hidden_states = self.dense(hidden_states)
292
+ hidden_states = self.dropout(hidden_states)
293
+
294
+ return hidden_states
295
+
296
+
297
+ class ViTAttention(nn.Module):
298
+ def __init__(self, config: ViTConfig) -> None:
299
+ super().__init__()
300
+ self.attention = ViTSelfAttention(config)
301
+ self.output = ViTSelfOutput(config)
302
+ self.pruned_heads = set()
303
+
304
+ def prune_heads(self, heads: Set[int]) -> None:
305
+ if len(heads) == 0:
306
+ return
307
+ heads, index = find_pruneable_heads_and_indices(
308
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
309
+ )
310
+
311
+ # Prune linear layers
312
+ self.attention.query = prune_linear_layer(self.attention.query, index)
313
+ self.attention.key = prune_linear_layer(self.attention.key, index)
314
+ self.attention.value = prune_linear_layer(self.attention.value, index)
315
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
316
+
317
+ # Update hyper params and store pruned heads
318
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
319
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: torch.Tensor,
325
+ head_mask: Optional[torch.Tensor] = None,
326
+ output_attentions: bool = False,
327
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
328
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
329
+
330
+ attention_output = self.output(self_outputs[0], hidden_states)
331
+
332
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
333
+ return outputs
334
+
335
+
336
+ class ViTIntermediate(nn.Module):
337
+ def __init__(self, config: ViTConfig) -> None:
338
+ super().__init__()
339
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
340
+ if isinstance(config.hidden_act, str):
341
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
342
+ else:
343
+ self.intermediate_act_fn = config.hidden_act
344
+
345
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
346
+ hidden_states = self.dense(hidden_states)
347
+ hidden_states = self.intermediate_act_fn(hidden_states)
348
+
349
+ return hidden_states
350
+
351
+
352
+ class ViTOutput(nn.Module):
353
+ def __init__(self, config: ViTConfig) -> None:
354
+ super().__init__()
355
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
356
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
357
+
358
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.dropout(hidden_states)
361
+
362
+ hidden_states = hidden_states + input_tensor
363
+
364
+ return hidden_states
365
+
366
+
367
+ class ViTLayer(nn.Module):
368
+ """This corresponds to the Block class in the timm implementation."""
369
+
370
+ def __init__(self, config: ViTConfig) -> None:
371
+ super().__init__()
372
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
373
+ self.seq_len_dim = 1
374
+ self.attention = ViTAttention(config)
375
+ self.intermediate = ViTIntermediate(config)
376
+ self.output = ViTOutput(config)
377
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
378
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
379
+
380
+ def forward(
381
+ self,
382
+ hidden_states: torch.Tensor,
383
+ head_mask: Optional[torch.Tensor] = None,
384
+ output_attentions: bool = False,
385
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
386
+ self_attention_outputs = self.attention(
387
+ self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
388
+ head_mask,
389
+ output_attentions=output_attentions,
390
+ )
391
+ attention_output = self_attention_outputs[0]
392
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
393
+
394
+ # first residual connection
395
+ hidden_states = attention_output + hidden_states
396
+
397
+ # in ViT, layernorm is also applied after self-attention
398
+ layer_output = self.layernorm_after(hidden_states)
399
+ layer_output = self.intermediate(layer_output)
400
+
401
+ # second residual connection is done here
402
+ layer_output = self.output(layer_output, hidden_states)
403
+
404
+ outputs = (layer_output,) + outputs
405
+
406
+ return outputs
407
+
408
+
409
+ class ViTEncoder(nn.Module):
410
+ def __init__(self, config: ViTConfig) -> None:
411
+ super().__init__()
412
+ self.config = config
413
+ self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
414
+ self.gradient_checkpointing = False
415
+
416
+ def forward(
417
+ self,
418
+ hidden_states: torch.Tensor,
419
+ head_mask: Optional[torch.Tensor] = None,
420
+ output_attentions: bool = False,
421
+ output_hidden_states: bool = False,
422
+ return_dict: bool = True,
423
+ ) -> Union[tuple, BaseModelOutput]:
424
+ all_hidden_states = () if output_hidden_states else None
425
+ all_self_attentions = () if output_attentions else None
426
+
427
+ for i, layer_module in enumerate(self.layer):
428
+ if output_hidden_states:
429
+ all_hidden_states = all_hidden_states + (hidden_states,)
430
+
431
+ layer_head_mask = head_mask[i] if head_mask is not None else None
432
+
433
+ if self.gradient_checkpointing and self.training:
434
+ layer_outputs = self._gradient_checkpointing_func(
435
+ layer_module.__call__,
436
+ hidden_states,
437
+ layer_head_mask,
438
+ output_attentions,
439
+ )
440
+ else:
441
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
442
+
443
+ hidden_states = layer_outputs[0]
444
+
445
+ if output_attentions:
446
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
447
+
448
+ if output_hidden_states:
449
+ all_hidden_states = all_hidden_states + (hidden_states,)
450
+
451
+ if not return_dict:
452
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
453
+ return BaseModelOutput(
454
+ last_hidden_state=hidden_states,
455
+ hidden_states=all_hidden_states,
456
+ attentions=all_self_attentions,
457
+ )
458
+
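+ # Gradient checkpointing sketch: the flag above is toggled through the `PreTrainedModel` API
+ # rather than set directly, e.g.
+ #
+ #     model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
+ #     model.gradient_checkpointing_enable()  # trades recompute for activation memory at train time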
459
+
460
+ class ViTPreTrainedModel(PreTrainedModel):
461
+ """
462
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
463
+ models.
464
+ """
465
+
466
+ config_class = ViTConfig
467
+ base_model_prefix = "vit"
468
+ main_input_name = "pixel_values"
469
+ supports_gradient_checkpointing = True
470
+ _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
471
+ _supports_sdpa = True
472
+ _supports_flash_attn_2 = True
473
+
474
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
475
+ """Initialize the weights"""
476
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
477
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
478
+ # `trunc_normal_cpu` not implemented in `half` issues
479
+ module.weight.data = nn.init.trunc_normal_(
480
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
481
+ ).to(module.weight.dtype)
482
+ if module.bias is not None:
483
+ module.bias.data.zero_()
484
+ elif isinstance(module, nn.LayerNorm):
485
+ module.bias.data.zero_()
486
+ module.weight.data.fill_(1.0)
487
+ elif isinstance(module, ViTEmbeddings):
488
+ module.position_embeddings.data = nn.init.trunc_normal_(
489
+ module.position_embeddings.data.to(torch.float32),
490
+ mean=0.0,
491
+ std=self.config.initializer_range,
492
+ ).to(module.position_embeddings.dtype)
493
+
494
+ module.cls_token.data = nn.init.trunc_normal_(
495
+ module.cls_token.data.to(torch.float32),
496
+ mean=0.0,
497
+ std=self.config.initializer_range,
498
+ ).to(module.cls_token.dtype)
499
+
500
+ if module.mask_token is not None:
501
+ module.mask_token.data.zero_()
502
+
503
+
504
+ VIT_START_DOCSTRING = r"""
505
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
506
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
507
+ behavior.
508
+
509
+ Parameters:
510
+ config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
511
+ Initializing with a config file does not load the weights associated with the model, only the
512
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
513
+ """
514
+
515
+ VIT_INPUTS_DOCSTRING = r"""
516
+ Args:
517
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
518
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
519
+ for details.
520
+
521
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
522
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
523
+
524
+ - 1 indicates the head is **not masked**,
525
+ - 0 indicates the head is **masked**.
526
+
527
+ output_attentions (`bool`, *optional*):
528
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
529
+ tensors for more detail.
530
+ output_hidden_states (`bool`, *optional*):
531
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
532
+ more detail.
533
+ interpolate_pos_encoding (`bool`, *optional*):
534
+ Whether to interpolate the pre-trained position encodings.
535
+ return_dict (`bool`, *optional*):
536
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
537
+ """
538
+
539
+
540
+ @add_start_docstrings(
541
+ "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
542
+ VIT_START_DOCSTRING,
543
+ )
544
+ class ViTModel(ViTPreTrainedModel):
545
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
546
+ super().__init__(config)
547
+ self.config = config
548
+
549
+ self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
550
+ self.encoder = ViTEncoder(config)
551
+
552
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
553
+ self.pooler = ViTPooler(config) if add_pooling_layer else None
554
+
555
+ # Initialize weights and apply final processing
556
+ self.post_init()
557
+
558
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
559
+ return self.embeddings.patch_embeddings
560
+
561
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
562
+ """
563
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
564
+ class PreTrainedModel
565
+ """
566
+ for layer, heads in heads_to_prune.items():
567
+ self.encoder.layer[layer].attention.prune_heads(heads)
568
+
569
+ @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
570
+ @add_code_sample_docstrings(
571
+ checkpoint=_CHECKPOINT_FOR_DOC,
572
+ output_type=BaseModelOutputWithPooling,
573
+ config_class=_CONFIG_FOR_DOC,
574
+ modality="vision",
575
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
576
+ )
577
+ def forward(
578
+ self,
579
+ pixel_values: Optional[torch.Tensor] = None,
580
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
581
+ head_mask: Optional[torch.Tensor] = None,
582
+ output_attentions: Optional[bool] = None,
583
+ output_hidden_states: Optional[bool] = None,
584
+ interpolate_pos_encoding: Optional[bool] = None,
585
+ return_dict: Optional[bool] = None,
586
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
587
+ r"""
588
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
589
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
590
+ """
591
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
592
+ output_hidden_states = (
593
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
594
+ )
595
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
596
+
597
+ if pixel_values is None:
598
+ raise ValueError("You have to specify pixel_values")
599
+
600
+ # Prepare head mask if needed
601
+ # 1.0 in head_mask indicate we keep the head
602
+ # attention_probs has shape bsz x n_heads x N x N
603
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
604
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
605
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
606
+
607
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
608
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
609
+ if pixel_values.dtype != expected_dtype:
610
+ pixel_values = pixel_values.to(expected_dtype)
611
+
612
+ embedding_output = self.embeddings(
613
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
614
+ )
615
+
616
+ encoder_outputs = self.encoder(
617
+ embedding_output,
618
+ head_mask=head_mask,
619
+ output_attentions=output_attentions,
620
+ output_hidden_states=output_hidden_states,
621
+ return_dict=return_dict,
622
+ )
623
+ sequence_output = encoder_outputs[0]
624
+ sequence_output = self.layernorm(sequence_output)
625
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
626
+
627
+ if not return_dict:
628
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
629
+ return head_outputs + encoder_outputs[1:]
630
+
631
+ return BaseModelOutputWithPooling(
632
+ last_hidden_state=sequence_output,
633
+ pooler_output=pooled_output,
634
+ hidden_states=encoder_outputs.hidden_states,
635
+ attentions=encoder_outputs.attentions,
636
+ )
637
+
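+ # A hedged PyTorch usage sketch mirroring the documented checkpoint for this model
+ # (`image` is a PIL.Image):
+ #
+ #     from transformers import AutoImageProcessor, ViTModel
+ #
+ #     processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
+ #     model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
+ #     with torch.no_grad():
+ #         outputs = model(**processor(images=image, return_tensors="pt"))
+ #     outputs.last_hidden_state.shape  # torch.Size([1, 197, 768])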
638
+
639
+ class ViTPooler(nn.Module):
640
+ def __init__(self, config: ViTConfig):
641
+ super().__init__()
642
+ self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
643
+ self.activation = ACT2FN[config.pooler_act]
644
+
645
+ def forward(self, hidden_states):
646
+ # We "pool" the model by simply taking the hidden state corresponding
647
+ # to the first token.
648
+ first_token_tensor = hidden_states[:, 0]
649
+ pooled_output = self.dense(first_token_tensor)
650
+ pooled_output = self.activation(pooled_output)
651
+ return pooled_output
652
+
653
+
654
+ @add_start_docstrings(
655
+ """ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).
656
+
657
+ <Tip>
658
+
659
+ Note that we provide a script to pre-train this model on custom data in our [examples
660
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
661
+
662
+ </Tip>
663
+ """,
664
+ VIT_START_DOCSTRING,
665
+ )
666
+ class ViTForMaskedImageModeling(ViTPreTrainedModel):
667
+ def __init__(self, config: ViTConfig) -> None:
668
+ super().__init__(config)
669
+
670
+ self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
671
+
672
+ self.decoder = nn.Sequential(
673
+ nn.Conv2d(
674
+ in_channels=config.hidden_size,
675
+ out_channels=config.encoder_stride**2 * config.num_channels,
676
+ kernel_size=1,
677
+ ),
678
+ nn.PixelShuffle(config.encoder_stride),
679
+ )
680
+
681
+ # Initialize weights and apply final processing
682
+ self.post_init()
683
+
684
+ @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
685
+ @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
686
+ def forward(
687
+ self,
688
+ pixel_values: Optional[torch.Tensor] = None,
689
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
690
+ head_mask: Optional[torch.Tensor] = None,
691
+ output_attentions: Optional[bool] = None,
692
+ output_hidden_states: Optional[bool] = None,
693
+ interpolate_pos_encoding: Optional[bool] = None,
694
+ return_dict: Optional[bool] = None,
695
+ ) -> Union[tuple, MaskedImageModelingOutput]:
696
+ r"""
697
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
698
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
699
+
700
+ Returns:
701
+
702
+ Examples:
703
+ ```python
704
+ >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
705
+ >>> import torch
706
+ >>> from PIL import Image
707
+ >>> import requests
708
+
709
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
710
+ >>> image = Image.open(requests.get(url, stream=True).raw)
711
+
712
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
713
+ >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
714
+
715
+ >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
716
+ >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
717
+ >>> # create random boolean mask of shape (batch_size, num_patches)
718
+ >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
719
+
720
+ >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
721
+ >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
722
+ >>> list(reconstructed_pixel_values.shape)
723
+ [1, 3, 224, 224]
724
+ ```"""
725
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
726
+
727
+ if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
728
+ raise ValueError(
729
+ "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
730
+ "the reconstructed image has the same dimensions as the input. "
731
+ f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
732
+ )
733
+
734
+ outputs = self.vit(
735
+ pixel_values,
736
+ bool_masked_pos=bool_masked_pos,
737
+ head_mask=head_mask,
738
+ output_attentions=output_attentions,
739
+ output_hidden_states=output_hidden_states,
740
+ interpolate_pos_encoding=interpolate_pos_encoding,
741
+ return_dict=return_dict,
742
+ )
743
+
744
+ sequence_output = outputs[0]
745
+
746
+ # Reshape to (batch_size, num_channels, height, width)
747
+ sequence_output = sequence_output[:, 1:]
748
+ batch_size, sequence_length, num_channels = sequence_output.shape
749
+ height = width = math.floor(sequence_length**0.5)
750
+ sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
751
+
752
+ # Reconstruct pixel values
753
+ reconstructed_pixel_values = self.decoder(sequence_output)
754
+
755
+ masked_im_loss = None
756
+ if bool_masked_pos is not None:
757
+ size = self.config.image_size // self.config.patch_size
758
+ bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
759
+ mask = (
760
+ bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
761
+ .repeat_interleave(self.config.patch_size, 2)
762
+ .unsqueeze(1)
763
+ .contiguous()
764
+ )
765
+ reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
766
+ masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
767
+
768
+ if not return_dict:
769
+ output = (reconstructed_pixel_values,) + outputs[1:]
770
+ return ((masked_im_loss,) + output) if masked_im_loss is not None else output
771
+
772
+ return MaskedImageModelingOutput(
773
+ loss=masked_im_loss,
774
+ reconstruction=reconstructed_pixel_values,
775
+ hidden_states=outputs.hidden_states,
776
+ attentions=outputs.attentions,
777
+ )
778
+
779
+
780
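A brief aside on the loss above: only masked patches are scored. The patch-level boolean mask is upsampled to pixel resolution with two `repeat_interleave` calls and then weights an element-wise L1 loss. A minimal standalone sketch of that bookkeeping (toy sizes, plain PyTorch; illustrative only, not library code):

```python
import torch

# Toy setup: a 2x2 patch grid with 4-pixel patches, one channel.
patch_size, grid = 4, 2
pixels = torch.randn(1, 1, grid * patch_size, grid * patch_size)
recon = torch.randn_like(pixels)

# Patch-level mask (1 = masked), blown up to pixel resolution.
bool_masked_pos = torch.tensor([[1, 0, 0, 1]]).reshape(-1, grid, grid)
mask = (
    bool_masked_pos.repeat_interleave(patch_size, 1)
    .repeat_interleave(patch_size, 2)
    .unsqueeze(1)
)

# L1 averaged over masked pixels only; the 1e-5 guards an all-zero mask.
loss = (torch.abs(pixels - recon) * mask).sum() / (mask.sum() + 1e-5)
```
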
@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"]
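
For reference, a typical inference call for the classification head (the checkpoint name is an example; any ViT classification checkpoint works the same way):

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

with torch.no_grad():
    logits = model(**processor(images=image, return_tensors="pt")).logits
print(model.config.id2label[logits.argmax(-1).item()])
```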
docs/transformers/build/lib/transformers/models/vit_mae/__init__.py ADDED
@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_vit_mae import *
    from .modeling_tf_vit_mae import *
    from .modeling_vit_mae import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
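
The `_LazyModule` indirection above replaces the package object in `sys.modules`, so the heavy framework imports behind each submodule only happen when an attribute is first accessed. A rough sketch of the observable behavior (the exact proxy type name is an implementation detail):

```python
import transformers.models.vit_mae as vit_mae

print(type(vit_mae).__name__)      # typically "_LazyModule" before first access
config_cls = vit_mae.ViTMAEConfig  # first attribute access triggers the real import
print(config_cls.model_type)       # "vit_mae"
```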
docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py ADDED
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT MAE model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class ViTMAEConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate a ViT
    MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the ViT
    [facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys and values.
        decoder_num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        decoder_hidden_size (`int`, *optional*, defaults to 512):
            Dimensionality of the decoder.
        decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of hidden layers in the decoder.
        decoder_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
        mask_ratio (`float`, *optional*, defaults to 0.75):
            The ratio of the number of masked tokens in the input sequence.
        norm_pix_loss (`bool`, *optional*, defaults to `False`):
            Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels
            improved representation quality in the experiments of the authors.

    Example:

    ```python
    >>> from transformers import ViTMAEConfig, ViTMAEModel

    >>> # Initializing a ViT MAE vit-mae-base style configuration
    >>> configuration = ViTMAEConfig()

    >>> # Initializing a model (with random weights) from the vit-mae-base style configuration
    >>> model = ViTMAEModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "vit_mae"

    def __init__(
        self,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        image_size=224,
        patch_size=16,
        num_channels=3,
        qkv_bias=True,
        decoder_num_attention_heads=16,
        decoder_hidden_size=512,
        decoder_num_hidden_layers=8,
        decoder_intermediate_size=2048,
        mask_ratio=0.75,
        norm_pix_loss=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.decoder_num_attention_heads = decoder_num_attention_heads
        self.decoder_hidden_size = decoder_hidden_size
        self.decoder_num_hidden_layers = decoder_num_hidden_layers
        self.decoder_intermediate_size = decoder_intermediate_size
        self.mask_ratio = mask_ratio
        self.norm_pix_loss = norm_pix_loss


__all__ = ["ViTMAEConfig"]
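
With the defaults above, the encoder sees only the visible quarter of the patch sequence. The arithmetic, spelled out:

```python
from transformers import ViTMAEConfig

config = ViTMAEConfig()
num_patches = (config.image_size // config.patch_size) ** 2  # (224 // 16) ** 2 = 196
num_visible = int(num_patches * (1 - config.mask_ratio))     # 49 patches kept
print(num_patches, num_visible, num_patches - num_visible)   # 196 49 147
```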
docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py ADDED
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT MAE checkpoints from the original repository: https://github.com/facebookresearch/mae"""

import argparse

import requests
import torch
from PIL import Image

from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEImageProcessor


def rename_key(name):
    if "cls_token" in name:
        name = name.replace("cls_token", "vit.embeddings.cls_token")
    if "mask_token" in name:
        name = name.replace("mask_token", "decoder.mask_token")
    if "decoder_pos_embed" in name:
        name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed")
    if "pos_embed" in name and "decoder" not in name:
        name = name.replace("pos_embed", "vit.embeddings.position_embeddings")
    if "patch_embed.proj" in name:
        name = name.replace("patch_embed.proj", "vit.embeddings.patch_embeddings.projection")
    if "patch_embed.norm" in name:
        name = name.replace("patch_embed.norm", "vit.embeddings.norm")
    if "decoder_blocks" in name:
        name = name.replace("decoder_blocks", "decoder.decoder_layers")
    if "blocks" in name:
        name = name.replace("blocks", "vit.encoder.layer")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "attn" in name:
        name = name.replace("attn", "attention.self")
    if "norm1" in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name:
        name = name.replace("norm2", "layernorm_after")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")
    if "decoder_embed" in name:
        name = name.replace("decoder_embed", "decoder.decoder_embed")
    if "decoder_norm" in name:
        name = name.replace("decoder_norm", "decoder.decoder_norm")
    if "decoder_pred" in name:
        name = name.replace("decoder_pred", "decoder.decoder_pred")
    if "norm.weight" in name and "decoder" not in name:
        name = name.replace("norm.weight", "vit.layernorm.weight")
    if "norm.bias" in name and "decoder" not in name:
        name = name.replace("norm.bias", "vit.layernorm.bias")

    return name

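A quick sanity check of the mapping above, run with `rename_key` in scope (toy keys; the outputs follow directly from the replacement rules):

```python
print(rename_key("blocks.0.norm1.weight"))
# vit.encoder.layer.0.layernorm_before.weight
print(rename_key("decoder_blocks.2.mlp.fc1.bias"))
# decoder.decoder_layers.2.intermediate.dense.bias
```
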
def convert_state_dict(orig_state_dict, config):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if "qkv" in key:
            key_split = key.split(".")
            layer_num = int(key_split[1])
            if "decoder_blocks" in key:
                dim = config.decoder_hidden_size
                prefix = "decoder.decoder_layers."
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
            else:
                dim = config.hidden_size
                prefix = "vit.encoder.layer."
                if "weight" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
                elif "bias" in key:
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
                    orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]

        else:
            orig_state_dict[rename_key(key)] = val

    return orig_state_dict

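The `qkv` branch above undoes timm's fused attention projection: a single `(3 * dim, dim)` weight holds query, key, and value stacked row-wise, and the conversion slices it into three equal blocks. A standalone illustration of the slicing:

```python
import torch

dim = 8
fused_qkv = torch.randn(3 * dim, dim)  # q, k, v stacked along dim 0

query, key, value = fused_qkv[:dim], fused_qkv[dim : dim * 2], fused_qkv[-dim:]
assert torch.equal(torch.cat([query, key, value]), fused_qkv)
```
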
def convert_vit_mae_checkpoint(checkpoint_url, pytorch_dump_folder_path):
    config = ViTMAEConfig()
    if "large" in checkpoint_url:
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif "huge" in checkpoint_url:
        config.patch_size = 14
        config.hidden_size = 1280
        config.intermediate_size = 5120
        config.num_hidden_layers = 32
        config.num_attention_heads = 16

    model = ViTMAEForPreTraining(config)

    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]

    image_processor = ViTMAEImageProcessor(size=config.image_size)

    new_state_dict = convert_state_dict(state_dict, config)

    model.load_state_dict(new_state_dict)
    model.eval()

    url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"

    image = Image.open(requests.get(url, stream=True).raw)
    image_processor = ViTMAEImageProcessor(size=config.image_size)
    inputs = image_processor(images=image, return_tensors="pt")

    # forward pass
    torch.manual_seed(2)
    outputs = model(**inputs)
    logits = outputs.logits

    if "large" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-0.7309, -0.7128, -1.0169], [-1.0161, -0.9058, -1.1878], [-1.0478, -0.9411, -1.1911]]
        )
    elif "huge" in checkpoint_url:
        expected_slice = torch.tensor(
            [[-1.1599, -0.9199, -1.2221], [-1.1952, -0.9269, -1.2307], [-1.2143, -0.9337, -1.2262]]
        )
    else:
        expected_slice = torch.tensor(
            [[-0.9192, -0.8481, -1.1259], [-1.1349, -1.0034, -1.2599], [-1.1757, -1.0429, -1.2726]]
        )

    # verify logits
    assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)

    print(f"Saving model to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)

    print(f"Saving image processor to {pytorch_dump_folder_path}")
    image_processor.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--checkpoint_url",
        default="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
        type=str,
        help="URL of the checkpoint you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_vit_mae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
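
Calling the converter directly is equivalent to the command-line entry point above (the output path here is an arbitrary example):

```python
convert_vit_mae_checkpoint(
    checkpoint_url="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
    pytorch_dump_folder_path="./vit-mae-base",
)
```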
docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py ADDED
@@ -0,0 +1,1375 @@
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ViT MAE (masked autoencoder) model."""

from __future__ import annotations

import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...file_utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_vit_mae import ViTMAEConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ViTMAEConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"


@dataclass
class TFViTMAEModelOutput(ModelOutput):
    """
    Class for TFViTMAEModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Tensor indicating which patches are masked (1) and which are not (0).
        ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the original index of the (shuffled) masked patches.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    last_hidden_state: Optional[tf.Tensor] = None
    mask: Optional[tf.Tensor] = None
    ids_restore: Optional[tf.Tensor] = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFViTMAEDecoderOutput(ModelOutput):
    """
    Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: Optional[tf.Tensor] = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


@dataclass
class TFViTMAEForPreTrainingOutput(ModelOutput):
    """
    Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`tf.Tensor` of shape `(1,)`):
            Pixel reconstruction loss.
        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Tensor indicating which patches are masked (1) and which are not (0).
        ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the original index of the (shuffled) masked patches.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: tf.Tensor | None = None
    logits: Optional[tf.Tensor] = None
    mask: Optional[tf.Tensor] = None
    ids_restore: Optional[tf.Tensor] = None
    hidden_states: Tuple[tf.Tensor] | None = None
    attentions: Tuple[tf.Tensor] | None = None


def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """
    Create 2D sin/cos positional embeddings.

    Args:
        embed_dim (`int`):
            Embedding dimension.
        grid_size (`int`):
            The grid height and width.
        add_cls_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add a classification (CLS) token.

    Returns:
        (`tf.Tensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim)): the position
        embeddings (with or without classification token)
    """
    grid_h = tf.range(grid_size, dtype=tf.float32)
    grid_w = tf.range(grid_size, dtype=tf.float32)
    grid = tf.meshgrid(grid_w, grid_h)  # here w goes first
    grid = tf.stack(grid, axis=0)

    grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if add_cls_token:
        pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = tf.concat([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded, of size (M,)
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    omega = tf.range(embed_dim // 2, dtype="float32")
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = tf.reshape(pos, [-1])  # (M,)
    out = tf.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    # half of the positions get the sinusoidal pattern and the rest get
    # the cosine pattern, and then they are concatenated
    emb_sin = tf.sin(out)  # (M, D/2)
    emb_cos = tf.cos(out)  # (M, D/2)

    emb = tf.concat([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

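The frequency schedule above is the classic Transformer recipe, `1 / 10000^(2i/d)`, with half the channels taking the sine and half the cosine. A framework-free sketch with the same shapes, assuming an even `embed_dim`:

```python
import numpy as np

def sincos_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)         # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

print(sincos_1d(8, np.arange(4)).shape)  # (4, 8)
```
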
class TFViTMAEEmbeddings(keras.layers.Layer):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
        self.num_patches = self.patch_embeddings.num_patches

        self.config = config

    def build(self, input_shape=None):
        self.cls_token = self.add_weight(
            shape=(1, 1, self.config.hidden_size),
            initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
            trainable=True,
            name="cls_token",
        )
        self.position_embeddings = self.add_weight(
            shape=(1, self.num_patches + 1, self.config.hidden_size),
            initializer="zeros",
            trainable=False,  # fixed sin-cos embedding
            name="position_embeddings",
        )
        pos_embed = get_2d_sincos_pos_embed(
            self.position_embeddings.shape[-1],
            int(self.patch_embeddings.num_patches**0.5),
            add_cls_token=True,
        )[None, ...]
        self.position_embeddings.assign(pos_embed)

        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)

    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
        """
        This method allows interpolating the pre-trained position encodings so that the model can be used on
        higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        batch_size, seq_len, dim = shape_list(embeddings)
        num_patches = seq_len - 1

        _, num_positions, _ = shape_list(self.position_embeddings)
        num_positions -= 1

        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(
                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
            ),
            size=(h0, w0),
            method="bicubic",
        )

        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)

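The resize above treats the `(num_positions, dim)` table as a square grid of embedding "pixels" and bicubically rescales it to the new patch grid. A toy check of the shape bookkeeping, assuming a 14x14 pre-training grid upscaled for 256x256 inputs with 16-pixel patches:

```python
import tensorflow as tf

dim, old_grid, new_grid = 32, 14, 16
patch_pos_embed = tf.random.normal((1, old_grid * old_grid, dim))

resized = tf.image.resize(
    tf.reshape(patch_pos_embed, (1, old_grid, old_grid, dim)),
    size=(new_grid, new_grid),
    method="bicubic",
)
print(tf.reshape(resized, (1, -1, dim)).shape)  # (1, 256, 32)
```
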
    def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
        """
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`)
            noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        """
        batch_size, seq_length, dim = shape_list(sequence)
        len_keep = int(seq_length * (1 - self.config.mask_ratio))

        if noise is None:
            noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0)  # noise in [0, 1)

        # sort noise for each sample
        ids_shuffle = tf.argsort(noise, axis=1)  # ascend: small is keep, large is remove
        ids_restore = tf.argsort(ids_shuffle, axis=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        sequence_unmasked = tf.gather(
            sequence,
            axis=1,
            batch_dims=1,
            indices=ids_keep,
        )

        # generate the binary mask: 0 is keep, 1 is remove
        # this hack is needed because TF's EagerTensors don't support
        # assignment
        mask_keep = tf.zeros((batch_size, len_keep))
        mask_remove = tf.ones((batch_size, seq_length - len_keep))
        mask = tf.concat([mask_keep, mask_remove], axis=-1)

        # unshuffle to get the binary mask
        mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)

        return sequence_unmasked, mask, ids_restore

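The key trick above is that the `argsort` of an `argsort` inverts a permutation, so `ids_restore` maps values from shuffled order back to patch order. A small demonstration of that invariant:

```python
import tensorflow as tf

noise = tf.random.uniform((1, 6))
ids_shuffle = tf.argsort(noise, axis=1)
ids_restore = tf.argsort(ids_shuffle, axis=1)

shuffled = tf.gather(noise, ids_shuffle, axis=1, batch_dims=1)
restored = tf.gather(shuffled, ids_restore, axis=1, batch_dims=1)
assert bool(tf.reduce_all(restored == noise))
```
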
    def call(
        self, pixel_values: tf.Tensor, noise: Optional[tf.Tensor] = None, interpolate_pos_encoding: bool = False
    ) -> tf.Tensor:
        batch_size, num_channels, height, width = shape_list(pixel_values)
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        if interpolate_pos_encoding:
            position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_embeddings = self.position_embeddings
        # add position embeddings w/o cls token
        embeddings = embeddings + position_embeddings[:, 1:, :]

        # masking: length -> length * config.mask_ratio
        embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

        # append cls token
        cls_token = self.cls_token + position_embeddings[:, :1, :]
        cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
        embeddings = tf.concat([cls_tokens, embeddings], axis=1)

        return embeddings, mask, ids_restore


class TFViTMAEPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_channels = num_channels
        self.config = config

        self.projection = keras.layers.Conv2D(
            filters=hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last",
            kernel_initializer="glorot_uniform",  # following torch.nn.Linear
            bias_initializer="zeros",
            name="projection",
        )

    def call(
        self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
    ) -> tf.Tensor:
        batch_size, num_channels, height, width = shape_list(pixel_values)
        if tf.executing_eagerly():
            if num_channels != self.num_channels:
                raise ValueError(
                    "Make sure that the channel dimension of the pixel values match with the one set in the"
                    " configuration."
                )
            if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        projection = self.projection(pixel_values)

        # Change the 2D spatial dimensions to a single temporal dimension.
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
        x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))

        return x

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])

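Because `keras.layers.Conv2D` cannot run channels-first on CPU, the layer transposes `NCHW` to `NHWC` before the strided convolution that performs the patchification. The shape flow, sketched with toy sizes:

```python
import tensorflow as tf

batch, channels, image, patch, hidden = 2, 3, 32, 16, 8
pixel_values = tf.random.normal((batch, channels, image, image))  # NCHW in
x = tf.transpose(pixel_values, perm=(0, 2, 3, 1))                 # NHWC for Conv2D
proj = tf.keras.layers.Conv2D(hidden, kernel_size=patch, strides=patch)(x)
print(tf.reshape(proj, (batch, -1, hidden)).shape)                # (2, 4, 8)
```
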
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE
class TFViTMAESelfAttention(keras.layers.Layer):
    def __init__(self, config: ViTMAEConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])

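The division by `sqrt(attention_head_size)` above keeps the pre-softmax logits at unit scale regardless of head width. A minimal shape walk-through of the score computation (toy sizes; plain softmax stands in for the library's `stable_softmax`):

```python
import math
import tensorflow as tf

batch, heads, seq, head_size = 2, 4, 5, 8
q = tf.random.normal((batch, heads, seq, head_size))
k = tf.random.normal((batch, heads, seq, head_size))

scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(head_size)
probs = tf.nn.softmax(scores, axis=-1)
print(probs.shape)  # (2, 4, 5, 5)
```
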
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
501
+ class TFViTMAESelfOutput(keras.layers.Layer):
502
+ """
503
+ The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the
504
+ layernorm applied before each block.
505
+ """
506
+
507
+ def __init__(self, config: ViTMAEConfig, **kwargs):
508
+ super().__init__(**kwargs)
509
+
510
+ self.dense = keras.layers.Dense(
511
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
512
+ )
513
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
514
+ self.config = config
515
+
516
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
517
+ hidden_states = self.dense(inputs=hidden_states)
518
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
519
+
520
+ return hidden_states
521
+
522
+ def build(self, input_shape=None):
523
+ if self.built:
524
+ return
525
+ self.built = True
526
+ if getattr(self, "dense", None) is not None:
527
+ with tf.name_scope(self.dense.name):
528
+ self.dense.build([None, None, self.config.hidden_size])
529
+
530
+
531
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
532
+ class TFViTMAEAttention(keras.layers.Layer):
533
+ def __init__(self, config: ViTMAEConfig, **kwargs):
534
+ super().__init__(**kwargs)
535
+
536
+ self.self_attention = TFViTMAESelfAttention(config, name="attention")
537
+ self.dense_output = TFViTMAESelfOutput(config, name="output")
538
+
539
+ def prune_heads(self, heads):
540
+ raise NotImplementedError
541
+
542
+ def call(
543
+ self,
544
+ input_tensor: tf.Tensor,
545
+ head_mask: tf.Tensor,
546
+ output_attentions: bool,
547
+ training: bool = False,
548
+ ) -> Tuple[tf.Tensor]:
549
+ self_outputs = self.self_attention(
550
+ hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
551
+ )
552
+ attention_output = self.dense_output(
553
+ hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
554
+ )
555
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
556
+
557
+ return outputs
558
+
559
+ def build(self, input_shape=None):
560
+ if self.built:
561
+ return
562
+ self.built = True
563
+ if getattr(self, "self_attention", None) is not None:
564
+ with tf.name_scope(self.self_attention.name):
565
+ self.self_attention.build(None)
566
+ if getattr(self, "dense_output", None) is not None:
567
+ with tf.name_scope(self.dense_output.name):
568
+ self.dense_output.build(None)
569
+
570
+
571
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
572
+ class TFViTMAEIntermediate(keras.layers.Layer):
573
+ def __init__(self, config: ViTMAEConfig, **kwargs):
574
+ super().__init__(**kwargs)
575
+
576
+ self.dense = keras.layers.Dense(
577
+ units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
578
+ )
579
+
580
+ if isinstance(config.hidden_act, str):
581
+ self.intermediate_act_fn = get_tf_activation(config.hidden_act)
582
+ else:
583
+ self.intermediate_act_fn = config.hidden_act
584
+ self.config = config
585
+
586
+ def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
587
+ hidden_states = self.dense(inputs=hidden_states)
588
+ hidden_states = self.intermediate_act_fn(hidden_states)
589
+
590
+ return hidden_states
591
+
592
+ def build(self, input_shape=None):
593
+ if self.built:
594
+ return
595
+ self.built = True
596
+ if getattr(self, "dense", None) is not None:
597
+ with tf.name_scope(self.dense.name):
598
+ self.dense.build([None, None, self.config.hidden_size])
599
+
600
+
601
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
602
+ class TFViTMAEOutput(keras.layers.Layer):
603
+ def __init__(self, config: ViTMAEConfig, **kwargs):
604
+ super().__init__(**kwargs)
605
+
606
+ self.dense = keras.layers.Dense(
607
+ units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
608
+ )
609
+ self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
610
+ self.config = config
611
+
612
+ def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
613
+ hidden_states = self.dense(inputs=hidden_states)
614
+ hidden_states = self.dropout(inputs=hidden_states, training=training)
615
+ hidden_states = hidden_states + input_tensor
616
+
617
+ return hidden_states
618
+
619
+ def build(self, input_shape=None):
620
+ if self.built:
621
+ return
622
+ self.built = True
623
+ if getattr(self, "dense", None) is not None:
624
+ with tf.name_scope(self.dense.name):
625
+ self.dense.build([None, None, self.config.intermediate_size])
626
+
627
+
628
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
629
+ class TFViTMAELayer(keras.layers.Layer):
630
+ """This corresponds to the Block class in the timm implementation."""
631
+
632
+ def __init__(self, config: ViTMAEConfig, **kwargs):
633
+ super().__init__(**kwargs)
634
+
635
+ self.attention = TFViTMAEAttention(config, name="attention")
636
+ self.intermediate = TFViTMAEIntermediate(config, name="intermediate")
637
+ self.vit_output = TFViTMAEOutput(config, name="output")
638
+
639
+ self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
640
+ self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
641
+ self.config = config
642
+
643
+ def call(
644
+ self,
645
+ hidden_states: tf.Tensor,
646
+ head_mask: tf.Tensor,
647
+ output_attentions: bool,
648
+ training: bool = False,
649
+ ) -> Tuple[tf.Tensor]:
650
+ attention_outputs = self.attention(
651
+ # in ViTMAE, layernorm is applied before self-attention
652
+ input_tensor=self.layernorm_before(inputs=hidden_states),
653
+ head_mask=head_mask,
654
+ output_attentions=output_attentions,
655
+ training=training,
656
+ )
657
+ attention_output = attention_outputs[0]
658
+
659
+ # first residual connection
660
+ hidden_states = attention_output + hidden_states
661
+
662
+ # in ViTMAE, layernorm is also applied after self-attention
663
+ layer_output = self.layernorm_after(inputs=hidden_states)
664
+
665
+ intermediate_output = self.intermediate(hidden_states=layer_output)
666
+
667
+ # second residual connection is done here
668
+ layer_output = self.vit_output(
669
+ hidden_states=intermediate_output, input_tensor=hidden_states, training=training
670
+ )
671
+ outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
672
+
673
+ return outputs
674
+
675
+ def build(self, input_shape=None):
676
+ if self.built:
677
+ return
678
+ self.built = True
679
+ if getattr(self, "attention", None) is not None:
680
+ with tf.name_scope(self.attention.name):
681
+ self.attention.build(None)
682
+ if getattr(self, "intermediate", None) is not None:
683
+ with tf.name_scope(self.intermediate.name):
684
+ self.intermediate.build(None)
685
+ if getattr(self, "vit_output", None) is not None:
686
+ with tf.name_scope(self.vit_output.name):
687
+ self.vit_output.build(None)
688
+ if getattr(self, "layernorm_before", None) is not None:
689
+ with tf.name_scope(self.layernorm_before.name):
690
+ self.layernorm_before.build([None, None, self.config.hidden_size])
691
+ if getattr(self, "layernorm_after", None) is not None:
692
+ with tf.name_scope(self.layernorm_after.name):
693
+ self.layernorm_after.build([None, None, self.config.hidden_size])
694
+
695
+
696
+ # Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE
697
+ class TFViTMAEEncoder(keras.layers.Layer):
698
+ def __init__(self, config: ViTMAEConfig, **kwargs):
699
+ super().__init__(**kwargs)
700
+
701
+ self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
702
+
703
+ def call(
704
+ self,
705
+ hidden_states: tf.Tensor,
706
+ head_mask: tf.Tensor,
707
+ output_attentions: bool,
708
+ output_hidden_states: bool,
709
+ return_dict: bool,
710
+ training: bool = False,
711
+ ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
712
+ all_hidden_states = () if output_hidden_states else None
713
+ all_attentions = () if output_attentions else None
714
+
715
+ for i, layer_module in enumerate(self.layer):
716
+ if output_hidden_states:
717
+ all_hidden_states = all_hidden_states + (hidden_states,)
718
+
719
+ layer_outputs = layer_module(
720
+ hidden_states=hidden_states,
721
+ head_mask=head_mask[i],
722
+ output_attentions=output_attentions,
723
+ training=training,
724
+ )
725
+ hidden_states = layer_outputs[0]
726
+
727
+ if output_attentions:
728
+ all_attentions = all_attentions + (layer_outputs[1],)
729
+
730
+ # Add last layer
731
+ if output_hidden_states:
732
+ all_hidden_states = all_hidden_states + (hidden_states,)
733
+
734
+ if not return_dict:
735
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
736
+
737
+ return TFBaseModelOutput(
738
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
739
+ )
740
+
741
+ def build(self, input_shape=None):
742
+ if self.built:
743
+ return
744
+ self.built = True
745
+ if getattr(self, "layer", None) is not None:
746
+ for layer in self.layer:
747
+ with tf.name_scope(layer.name):
748
+ layer.build(None)
749
+
750
+
751
+ @keras_serializable
752
+ class TFViTMAEMainLayer(keras.layers.Layer):
753
+ config_class = ViTMAEConfig
754
+
755
+ def __init__(self, config: ViTMAEConfig, **kwargs):
756
+ super().__init__(**kwargs)
757
+
758
+ self.config = config
759
+
760
+ self.embeddings = TFViTMAEEmbeddings(config, name="embeddings")
761
+ self.encoder = TFViTMAEEncoder(config, name="encoder")
762
+ self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
763
+
764
+ def get_input_embeddings(self) -> keras.layers.Layer:
765
+ return self.embeddings.patch_embeddings
766
+
767
+ def _prune_heads(self, heads_to_prune):
768
+ """
769
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
770
+ class PreTrainedModel
771
+ """
772
+ raise NotImplementedError
773
+
774
+ @unpack_inputs
775
+ def call(
776
+ self,
777
+ pixel_values: TFModelInputType | None = None,
778
+ noise: Optional[tf.Tensor] = None,
779
+ head_mask: np.ndarray | tf.Tensor | None = None,
780
+ output_attentions: Optional[bool] = None,
781
+ output_hidden_states: Optional[bool] = None,
782
+ return_dict: Optional[bool] = None,
783
+ training: bool = False,
784
+ interpolate_pos_encoding: bool = False,
785
+ ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
786
+ embedding_output, mask, ids_restore = self.embeddings(
787
+ pixel_values=pixel_values,
788
+ training=training,
789
+ noise=noise,
790
+ interpolate_pos_encoding=interpolate_pos_encoding,
791
+ )
792
+
793
+ # Prepare head mask if needed
794
+ # 1.0 in head_mask indicate we keep the head
795
+ # attention_probs has shape bsz x n_heads x N x N
796
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
797
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
798
+ if head_mask is not None:
799
+ raise NotImplementedError
800
+ else:
801
+ head_mask = [None] * self.config.num_hidden_layers
802
+
803
+ encoder_outputs = self.encoder(
804
+ embedding_output,
805
+ head_mask=head_mask,
806
+ output_attentions=output_attentions,
807
+ output_hidden_states=output_hidden_states,
808
+ return_dict=return_dict,
809
+ training=training,
810
+ )
811
+
812
+ sequence_output = encoder_outputs[0]
813
+ sequence_output = self.layernorm(inputs=sequence_output)
814
+
815
+ if not return_dict:
816
+ return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
817
+
818
+ return TFViTMAEModelOutput(
819
+ last_hidden_state=sequence_output,
820
+ mask=mask,
821
+ ids_restore=ids_restore,
822
+ hidden_states=encoder_outputs.hidden_states,
823
+ attentions=encoder_outputs.attentions,
824
+ )
825
+
826
+ def build(self, input_shape=None):
827
+ if self.built:
828
+ return
829
+ self.built = True
830
+ if getattr(self, "embeddings", None) is not None:
831
+ with tf.name_scope(self.embeddings.name):
832
+ self.embeddings.build(None)
833
+ if getattr(self, "encoder", None) is not None:
834
+ with tf.name_scope(self.encoder.name):
835
+ self.encoder.build(None)
836
+ if getattr(self, "layernorm", None) is not None:
837
+ with tf.name_scope(self.layernorm.name):
838
+ self.layernorm.build([None, None, self.config.hidden_size])
839
+
840
+
841
+ class TFViTMAEPreTrainedModel(TFPreTrainedModel):
842
+ """
843
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
844
+ models.
845
+ """
846
+
847
+ config_class = ViTMAEConfig
848
+ base_model_prefix = "vit"
849
+ main_input_name = "pixel_values"
850
+
851
+
852
+ VIT_MAE_START_DOCSTRING = r"""
853
+ This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
854
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
855
+ etc.)
856
+
857
+ This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
858
+ as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
859
+ behavior.
860
+
861
+ <Tip>
862
+
863
+ TensorFlow models and layers in `transformers` accept two formats as input:
864
+
865
+ - having all inputs as keyword arguments (like PyTorch models), or
866
+ - having all inputs as a list, tuple or dict in the first positional argument.
867
+
868
+ The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
869
+ and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
870
+ pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
871
+ format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
872
+ the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
873
+ positional argument:
874
+
875
+ - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
876
+ - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
877
+ `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
878
+ - a dictionary with one or several input Tensors associated to the input names given in the docstring:
879
+ `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
880
+
881
+ Note that when creating models and layers with
882
+ [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
883
+ about any of this, as you can just pass inputs like you would to any other Python function!
884
+
885
+ </Tip>
886
+
887
+ Args:
888
+ config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
889
+ Initializing with a config file does not load the weights associated with the model, only the
890
+ configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
891
+ """
892
+
893
+ VIT_MAE_INPUTS_DOCSTRING = r"""
894
+ Args:
895
+ pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, each example of shape `(batch_size, num_channels, height, width)`):
896
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
897
+ for details.
898
+
899
+ head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
900
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
901
+ - 1 indicates the head is **not masked**,
902
+ - 0 indicates the head is **masked**.
903
+
904
+ output_attentions (`bool`, *optional*):
905
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
906
+ tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
907
+ config will be used instead.
908
+
909
+ output_hidden_states (`bool`, *optional*):
910
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
911
+ more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
912
+ used instead.
913
+
914
+ return_dict (`bool`, *optional*):
915
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
916
+ in eager mode, in graph mode the value will always be set to True.
917
+
918
+ training (`bool`, *optional*, defaults to `False`):
919
+ Whether or not to use the model in training mode (some modules like dropout modules have different
920
+ behaviors between training and evaluation).
921
+
922
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
923
+ Whether to interpolate the position encodings at the encoder and decoder.
924
+ """
925
+
926
+
927
+ @add_start_docstrings(
928
+ "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
929
+ VIT_MAE_START_DOCSTRING,
930
+ )
931
+ class TFViTMAEModel(TFViTMAEPreTrainedModel):
932
+ def __init__(self, config: ViTMAEConfig, *inputs, **kwargs):
933
+ super().__init__(config, *inputs, **kwargs)
934
+
935
+ self.vit = TFViTMAEMainLayer(config, name="vit")
936
+
937
+ def get_input_embeddings(self):
938
+ return self.vit.get_input_embeddings()
939
+
940
+ @unpack_inputs
941
+ @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
942
+ @replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
943
+ def call(
944
+ self,
945
+ pixel_values: TFModelInputType | None = None,
946
+ noise: Optional[tf.Tensor] = None,
947
+ head_mask: np.ndarray | tf.Tensor | None = None,
948
+ output_attentions: Optional[bool] = None,
949
+ output_hidden_states: Optional[bool] = None,
950
+ return_dict: Optional[bool] = None,
951
+ training: bool = False,
952
+ interpolate_pos_encoding: bool = False,
953
+ ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
954
+ r"""
955
+ Returns:
956
+
957
+ Examples:
958
+
959
+ ```python
960
+ >>> from transformers import AutoImageProcessor, TFViTMAEModel
961
+ >>> from PIL import Image
962
+ >>> import requests
963
+
964
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
965
+ >>> image = Image.open(requests.get(url, stream=True).raw)
966
+
967
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
968
+ >>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
969
+
970
+ >>> inputs = image_processor(images=image, return_tensors="tf")
971
+ >>> outputs = model(**inputs)
972
+ >>> last_hidden_states = outputs.last_hidden_state
973
+ ```"""
974
+ outputs = self.vit(
975
+ pixel_values=pixel_values,
976
+ noise=noise,
977
+ head_mask=head_mask,
978
+ output_attentions=output_attentions,
979
+ output_hidden_states=output_hidden_states,
980
+ return_dict=return_dict,
981
+ training=training,
982
+ interpolate_pos_encoding=interpolate_pos_encoding,
983
+ )
984
+
985
+ return outputs
986
+
987
+ def build(self, input_shape=None):
988
+ if self.built:
989
+ return
990
+ self.built = True
991
+ if getattr(self, "vit", None) is not None:
992
+ with tf.name_scope(self.vit.name):
993
+ self.vit.build(None)
994
+
995
+
996
+ class TFViTMAEDecoder(keras.layers.Layer):
997
+ def __init__(self, config, num_patches, **kwargs):
998
+ super().__init__(**kwargs)
999
+ self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed")
1000
+
1001
+ decoder_config = deepcopy(config)
1002
+ decoder_config.hidden_size = config.decoder_hidden_size
1003
+ decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
1004
+ decoder_config.num_attention_heads = config.decoder_num_attention_heads
1005
+ decoder_config.intermediate_size = config.decoder_intermediate_size
1006
+ self.decoder_layers = [
1007
+ TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers)
1008
+ ]
1009
+
1010
+ self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
1011
+ self.decoder_pred = keras.layers.Dense(
1012
+ config.patch_size**2 * config.num_channels,
1013
+ kernel_initializer=get_initializer(config.initializer_range),
1014
+ name="decoder_pred",
1015
+ ) # encoder to decoder
1016
+ self.config = config
1017
+ self.num_patches = num_patches
1018
+
1019
+ def build(self, input_shape=None):
1020
+ self.mask_token = self.add_weight(
1021
+ shape=(1, 1, self.config.decoder_hidden_size),
1022
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
1023
+ trainable=True,
1024
+ name="mask_token",
1025
+ )
1026
+ self.decoder_pos_embed = self.add_weight(
1027
+ shape=(1, self.num_patches + 1, self.config.decoder_hidden_size),
1028
+ initializer="zeros",
1029
+ trainable=False,
1030
+ name="decoder_pos_embed",
1031
+ )
1032
+ decoder_pos_embed = get_2d_sincos_pos_embed(
1033
+ self.decoder_pos_embed.shape[-1],
1034
+ int(self.num_patches**0.5),
1035
+ add_cls_token=True,
1036
+ )[None, ...]
1037
+ self.decoder_pos_embed.assign(decoder_pos_embed)
1038
+
1039
+ if self.built:
1040
+ return
1041
+ self.built = True
1042
+ if getattr(self, "decoder_embed", None) is not None:
1043
+ with tf.name_scope(self.decoder_embed.name):
1044
+ self.decoder_embed.build([None, None, self.config.hidden_size])
1045
+ if getattr(self, "decoder_norm", None) is not None:
1046
+ with tf.name_scope(self.decoder_norm.name):
1047
+ self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
1048
+ if getattr(self, "decoder_pred", None) is not None:
1049
+ with tf.name_scope(self.decoder_pred.name):
1050
+ self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
1051
+ if getattr(self, "decoder_layers", None) is not None:
1052
+ for layer in self.decoder_layers:
1053
+ with tf.name_scope(layer.name):
1054
+ layer.build(None)
1055
+
1056
+ def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
1057
+ """
1058
+ This method is a modified version of the ViT MAE interpolation function, adapted to the decoder: it allows
1059
+ interpolating the pre-trained decoder position encodings so that the model can be used on higher
1060
+ resolution images.
1061
+
1062
+ Source:
1063
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
1064
+ """
1065
+
1066
+ # [batch_size, num_patches + 1, hidden_size]
1067
+ _, num_positions, dim = shape_list(self.decoder_pos_embed)
1068
+
1069
+ # -1 removes the class dimension since we later append it without interpolation
1070
+ seq_len = shape_list(embeddings)[1] - 1
1071
+ num_positions = num_positions - 1
1072
+
1073
+ # Separation of class token and patch tokens
1074
+ class_pos_embed = self.decoder_pos_embed[:, :1, :]
1075
+ patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
1076
+
1077
+ # interpolate the position embeddings
1078
+ patch_pos_embed = tf.image.resize(
1079
+ images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
1080
+ size=(1, seq_len),
1081
+ method="bicubic",
1082
+ )
1083
+
1084
+ # [1, seq_len, hidden_size]
1085
+ patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
1086
+ # Adding the class token back
1087
+ return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
1088
+
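The resize above treats the flattened patch position embeddings as a one-pixel-tall image. A minimal standalone sketch of that trick (all sizes below are illustrative, not taken from any config):

```python
# Standalone sketch of the 1D bicubic resize used in
# TFViTMAEDecoder.interpolate_pos_encoding. Sizes are illustrative only.
import tensorflow as tf

dim, old_len, new_len = 16, 196, 256
patch_pos = tf.random.normal((1, old_len, dim))

resized = tf.image.resize(
    images=tf.reshape(patch_pos, (1, 1, old_len, dim)),  # treat as a 1 x old_len image
    size=(1, new_len),
    method="bicubic",
)
resized = tf.reshape(resized, (1, new_len, dim))
print(resized.shape)  # (1, 256, 16)
```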
1089
+ def call(
1090
+ self,
1091
+ hidden_states,
1092
+ ids_restore,
1093
+ output_attentions=False,
1094
+ output_hidden_states=False,
1095
+ return_dict=True,
1096
+ interpolate_pos_encoding=False,
1097
+ ):
1098
+ # embed tokens
1099
+ x = self.decoder_embed(hidden_states)
1100
+ # append mask tokens to sequence
1101
+ mask_tokens = tf.tile(
1102
+ self.mask_token,
1103
+ (shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
1104
+ )
1105
+ x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
1106
+ x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
1107
+ x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
1108
+ if interpolate_pos_encoding:
1109
+ decoder_pos_embed = self.interpolate_pos_encoding(x)
1110
+ else:
1111
+ decoder_pos_embed = self.decoder_pos_embed
1112
+ # add pos embed
1113
+ hidden_states = x + decoder_pos_embed
1114
+ # apply Transformer layers (blocks)
1115
+ all_hidden_states = () if output_hidden_states else None
1116
+ all_self_attentions = () if output_attentions else None
1117
+ for i, layer_module in enumerate(self.decoder_layers):
1118
+ if output_hidden_states:
1119
+ all_hidden_states = all_hidden_states + (hidden_states,)
1120
+
1121
+ layer_outputs = layer_module(
1122
+ hidden_states,
1123
+ head_mask=None,
1124
+ output_attentions=output_attentions,
1125
+ )
1126
+
1127
+ hidden_states = layer_outputs[0]
1128
+
1129
+ if output_attentions:
1130
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1131
+
1132
+ if output_hidden_states:
1133
+ all_hidden_states = all_hidden_states + (hidden_states,)
1134
+
1135
+ hidden_states = self.decoder_norm(hidden_states)
1136
+
1137
+ # predictor projection
1138
+ logits = self.decoder_pred(hidden_states)
1139
+
1140
+ # remove cls token
1141
+ logits = logits[:, 1:, :]
1142
+
1143
+ if not return_dict:
1144
+ return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
1145
+ return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
1146
+
1147
+
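The unshuffle in `call` above hinges on `tf.gather(..., batch_dims=1)` with `ids_restore` being the inverse permutation of the encoder-side shuffle. A self-contained sketch of that invariant (all names local to the example):

```python
# Self-contained sketch of the gather-based unshuffle used in TFViTMAEDecoder.call.
import tensorflow as tf

batch_size, seq_len, dim = 2, 6, 4
tokens = tf.random.normal((batch_size, seq_len, dim))

# The encoder shuffles by argsort of per-sample noise; argsort of the shuffle
# indices yields the inverse permutation.
noise = tf.random.uniform((batch_size, seq_len))
ids_shuffle = tf.argsort(noise, axis=1)
ids_restore = tf.argsort(ids_shuffle, axis=1)

shuffled = tf.gather(tokens, ids_shuffle, axis=1, batch_dims=1)
restored = tf.gather(shuffled, ids_restore, axis=1, batch_dims=1)

tf.debugging.assert_near(restored, tokens)  # gathering with ids_restore undoes the shuffle
```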
1148
+ @add_start_docstrings(
1149
+ "The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.",
1150
+ VIT_MAE_START_DOCSTRING,
1151
+ )
1152
+ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
1153
+ def __init__(self, config):
1154
+ super().__init__(config)
1155
+ self.config = config
1156
+
1157
+ self.vit = TFViTMAEMainLayer(config, name="vit")
1158
+ self.decoder = TFViTMAEDecoder(
1159
+ config,
1160
+ num_patches=self.vit.embeddings.num_patches,
1161
+ name="decoder",
1162
+ )
1163
+
1164
+ def get_input_embeddings(self):
1165
+ return self.vit.get_input_embeddings()
1166
+
1167
+ def _prune_heads(self, heads_to_prune):
1168
+ raise NotImplementedError
1169
+
1170
+ def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
1171
+ """
1172
+ Args:
1173
+ pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
1174
+ Pixel values.
1175
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
1176
+ The interpolation flag passed during the forward pass.
1177
+
1178
+ Returns:
1179
+ `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
1180
+ Patchified pixel values.
1181
+ """
1182
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
1183
+ # make sure channels are last
1184
+ if shape_list(pixel_values)[1] == num_channels:
1185
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
1186
+
1187
+ # sanity checks
1188
+ if not interpolate_pos_encoding:
1189
+ tf.debugging.assert_equal(
1190
+ shape_list(pixel_values)[1],
1191
+ shape_list(pixel_values)[2],
1192
+ message="Make sure the pixel values have a squared size",
1193
+ )
1194
+ tf.debugging.assert_equal(
1195
+ shape_list(pixel_values)[1] % patch_size,
1196
+ 0,
1197
+ message="Make sure the pixel values have a size that is divisible by the patch size",
1198
+ )
1199
+ tf.debugging.assert_equal(
1200
+ shape_list(pixel_values)[3],
1201
+ num_channels,
1202
+ message=(
1203
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
1204
+ ),
1205
+ )
1206
+
1207
+ # patchify
1208
+ batch_size = shape_list(pixel_values)[0]
1209
+ num_patches_h = shape_list(pixel_values)[1] // patch_size
1210
+ num_patches_w = shape_list(pixel_values)[2] // patch_size
1211
+ patchified_pixel_values = tf.reshape(
1212
+ pixel_values,
1213
+ (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
1214
+ )
1215
+ patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
1216
+ patchified_pixel_values = tf.reshape(
1217
+ patchified_pixel_values,
1218
+ (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
1219
+ )
1220
+ return patchified_pixel_values
1221
+
1222
+ def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
1223
+ """
1224
+ Args:
1225
+ patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
1226
+ Patchified pixel values.
1227
+ original_image_size (`Tuple[int, int]`, *optional*):
1228
+ Original image size.
1229
+
1230
+ Returns:
1231
+ `tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
1232
+ Pixel values.
1233
+ """
1234
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
1235
+ original_image_size = (
1236
+ original_image_size
1237
+ if original_image_size is not None
1238
+ else (self.config.image_size, self.config.image_size)
1239
+ )
1240
+ original_height, original_width = original_image_size
1241
+ num_patches_h = original_height // patch_size
1242
+ num_patches_w = original_width // patch_size
1243
+ # sanity check
1244
+ tf.debugging.assert_equal(
1245
+ num_patches_h * num_patches_w,
1246
+ shape_list(patchified_pixel_values)[1],
1247
+ message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
1248
+ )
1249
+
1250
+ # unpatchify
1251
+ batch_size = shape_list(patchified_pixel_values)[0]
1252
+ patchified_pixel_values = tf.reshape(
1253
+ patchified_pixel_values,
1254
+ (batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
1255
+ )
1256
+ patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
1257
+ pixel_values = tf.reshape(
1258
+ patchified_pixel_values,
1259
+ (batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
1260
+ )
1261
+ return pixel_values
1262
+
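`patchify` and `unpatchify` are exact inverses of one another. The NumPy sketch below reproduces the same reshape/einsum round trip with made-up sizes:

```python
# Sketch of the patchify/unpatchify round trip, using the same reshape + einsum
# pattern as the methods above; sizes are arbitrary and chosen for illustration.
import numpy as np

batch, height, width, channels, patch = 2, 8, 8, 3, 4
images = np.random.rand(batch, height, width, channels)

nh, nw = height // patch, width // patch
x = images.reshape(batch, nh, patch, nw, patch, channels)
x = np.einsum("nhpwqc->nhwpqc", x)
patches = x.reshape(batch, nh * nw, patch * patch * channels)

y = patches.reshape(batch, nh, nw, patch, patch, channels)
y = np.einsum("nhwpqc->nhpwqc", y)
restored = y.reshape(batch, nh * patch, nw * patch, channels)

assert np.allclose(restored, images)  # lossless round trip
```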
1263
+ def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
1264
+ """
1265
+ Args:
1266
+ pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
1267
+ Pixel values.
1268
+ pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
1269
+ Predicted pixel values.
1270
+ mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
1271
+ Tensor indicating which patches are masked (1) and which are not (0).
1272
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
1273
+ The interpolation flag passed during the forward pass.
1274
+
1275
+ Returns:
1276
+ `tf.Tensor`: Pixel reconstruction loss.
1277
+ """
1278
+ target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1279
+ if self.config.norm_pix_loss:
1280
+ mean = tf.reduce_mean(target, axis=-1, keepdims=True)
1281
+ var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
1282
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
1283
+
1284
+ loss = (pred - target) ** 2
1285
+ loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch
1286
+
1287
+ loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
1288
+ loss = tf.reshape(loss, (1,))
1289
+ return loss
1290
+
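Note that the loss is averaged only over masked patches (`mask == 1`), since the visible patches were already seen by the encoder. A NumPy sketch of the same masked-mean reduction, with made-up shapes:

```python
# NumPy sketch of the masked reconstruction loss in forward_loss (shapes made up).
import numpy as np

batch, num_patches, patch_dim = 2, 8, 48
pred = np.random.rand(batch, num_patches, patch_dim)
target = np.random.rand(batch, num_patches, patch_dim)
mask = (np.random.rand(batch, num_patches) > 0.25).astype(np.float64)  # 1 = masked

# Optional per-patch normalization (the norm_pix_loss branch above).
mean = target.mean(axis=-1, keepdims=True)
var = target.var(axis=-1, keepdims=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5

per_patch = ((pred - target) ** 2).mean(axis=-1)  # [batch, num_patches]
loss = (per_patch * mask).sum() / mask.sum()      # mean over masked patches only
```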
1291
+ @unpack_inputs
1292
+ @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
1293
+ @replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1294
+ def call(
1295
+ self,
1296
+ pixel_values: TFModelInputType | None = None,
1297
+ noise: Optional[tf.Tensor] = None,
1298
+ head_mask: np.ndarray | tf.Tensor | None = None,
1299
+ output_attentions: Optional[bool] = None,
1300
+ output_hidden_states: Optional[bool] = None,
1301
+ return_dict: Optional[bool] = None,
1302
+ training: bool = False,
1303
+ interpolate_pos_encoding: bool = False,
1304
+ ) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
1305
+ r"""
1306
+ Returns:
1307
+
1308
+ Examples:
1309
+
1310
+ ```python
1311
+ >>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining
1312
+ >>> from PIL import Image
1313
+ >>> import requests
1314
+
1315
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1316
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1317
+
1318
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
1319
+ >>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
1320
+
1321
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1322
+ >>> outputs = model(**inputs)
1323
+ >>> loss = outputs.loss
1324
+ >>> mask = outputs.mask
1325
+ >>> ids_restore = outputs.ids_restore
1326
+ ```"""
1327
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1328
+
1329
+ outputs = self.vit(
1330
+ pixel_values=pixel_values,
1331
+ noise=noise,
1332
+ head_mask=head_mask,
1333
+ output_attentions=output_attentions,
1334
+ output_hidden_states=output_hidden_states,
1335
+ return_dict=return_dict,
1336
+ training=training,
1337
+ interpolate_pos_encoding=interpolate_pos_encoding,
1338
+ )
1339
+
1340
+ latent = outputs.last_hidden_state
1341
+ ids_restore = outputs.ids_restore
1342
+ mask = outputs.mask
1343
+
1344
+ # [batch_size, num_patches, patch_size**2*3]
1345
+ decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
1346
+ logits = decoder_outputs.logits
1347
+
1348
+ loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
1349
+
1350
+ if not return_dict:
1351
+ output = (logits, mask, ids_restore) + outputs[2:]
1352
+ return ((loss,) + output) if loss is not None else output
1353
+
1354
+ return TFViTMAEForPreTrainingOutput(
1355
+ loss=loss,
1356
+ logits=logits,
1357
+ mask=mask,
1358
+ ids_restore=ids_restore,
1359
+ hidden_states=outputs.hidden_states,
1360
+ attentions=outputs.attentions,
1361
+ )
1362
+
1363
+ def build(self, input_shape=None):
1364
+ if self.built:
1365
+ return
1366
+ self.built = True
1367
+ if getattr(self, "vit", None) is not None:
1368
+ with tf.name_scope(self.vit.name):
1369
+ self.vit.build(None)
1370
+ if getattr(self, "decoder", None) is not None:
1371
+ with tf.name_scope(self.decoder.name):
1372
+ self.decoder.build(None)
1373
+
1374
+
1375
+ __all__ = ["TFViTMAEForPreTraining", "TFViTMAEModel", "TFViTMAEPreTrainedModel"]
docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py ADDED
@@ -0,0 +1,1163 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViT MAE (masked autoencoder) model."""
16
+
17
+ import collections.abc
18
+ from copy import deepcopy
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Optional, Set, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+
27
+ from ...activations import ACT2FN
28
+ from ...modeling_outputs import BaseModelOutput
29
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
30
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
31
+ from ...utils import (
32
+ ModelOutput,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ logging,
36
+ replace_return_docstrings,
37
+ torch_int,
38
+ )
39
+ from .configuration_vit_mae import ViTMAEConfig
40
+
41
+
42
+ logger = logging.get_logger(__name__)
43
+
44
+ _CONFIG_FOR_DOC = "ViTMAEConfig"
45
+ _CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
46
+
47
+
48
+ @dataclass
49
+ class ViTMAEModelOutput(ModelOutput):
50
+ """
51
+ Class for ViTMAEModel's outputs, with potential hidden states and attentions.
52
+
53
+ Args:
54
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
55
+ Sequence of hidden-states at the output of the last layer of the model.
56
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
57
+ Tensor indicating which patches are masked (1) and which are not (0).
58
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
59
+ Tensor containing the original index of the (shuffled) masked patches.
60
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
61
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
62
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
63
+ plus the initial embedding outputs.
64
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
65
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
66
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
67
+ the self-attention heads.
68
+ """
69
+
70
+ last_hidden_state: Optional[torch.FloatTensor] = None
71
+ mask: Optional[torch.LongTensor] = None
72
+ ids_restore: Optional[torch.LongTensor] = None
73
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
74
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
75
+
76
+
77
+ @dataclass
78
+ class ViTMAEDecoderOutput(ModelOutput):
79
+ """
80
+ Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.
81
+
82
+ Args:
83
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
84
+ Pixel reconstruction logits.
85
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
86
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
87
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
88
+ plus the initial embedding outputs.
89
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
90
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
91
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
92
+ the self-attention heads.
93
+ """
94
+
95
+ logits: Optional[torch.FloatTensor] = None
96
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
97
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
98
+
99
+
100
+ @dataclass
101
+ class ViTMAEForPreTrainingOutput(ModelOutput):
102
+ """
103
+ Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions.
104
+
105
+ Args:
106
+ loss (`torch.FloatTensor` of shape `(1,)`):
107
+ Pixel reconstruction loss.
108
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
109
+ Pixel reconstruction logits.
110
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
111
+ Tensor indicating which patches are masked (1) and which are not (0).
112
+ ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
113
+ Tensor containing the original index of the (shuffled) masked patches.
114
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
115
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
116
+ shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
117
+ plus the initial embedding outputs.
118
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
119
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
120
+ sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
121
+ the self-attention heads.
122
+ """
123
+
124
+ loss: Optional[torch.FloatTensor] = None
125
+ logits: Optional[torch.FloatTensor] = None
126
+ mask: Optional[torch.LongTensor] = None
127
+ ids_restore: Optional[torch.LongTensor] = None
128
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
129
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
130
+
131
+
132
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
133
+ """
134
+ Create 2D sin/cos positional embeddings.
135
+
136
+ Args:
137
+ embed_dim (`int`):
138
+ Embedding dimension.
139
+ grid_size (`int`):
140
+ The grid height and width.
141
+ add_cls_token (`bool`, *optional*, defaults to `False`):
142
+ Whether or not to add a classification (CLS) token.
143
+
144
+ Returns:
145
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim)): the
146
+ position embeddings (with or without classification token)
147
+ """
148
+ grid_h = np.arange(grid_size, dtype=np.float32)
149
+ grid_w = np.arange(grid_size, dtype=np.float32)
150
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
151
+ grid = np.stack(grid, axis=0)
152
+
153
+ grid = grid.reshape([2, 1, grid_size, grid_size])
154
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
155
+ if add_cls_token:
156
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
157
+ return pos_embed
158
+
159
+
160
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
161
+ if embed_dim % 2 != 0:
162
+ raise ValueError("embed_dim must be even")
163
+
164
+ # use half of dimensions to encode grid_h
165
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
166
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
167
+
168
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
169
+ return emb
170
+
171
+
172
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
173
+ """
174
+ embed_dim: output dimension for each position. pos: a list of positions to be encoded, of size (M,). out: (M, D)
175
+ """
176
+ if embed_dim % 2 != 0:
177
+ raise ValueError("embed_dim must be even")
178
+
179
+ omega = np.arange(embed_dim // 2, dtype=float)
180
+ omega /= embed_dim / 2.0
181
+ omega = 1.0 / 10000**omega # (D/2,)
182
+
183
+ pos = pos.reshape(-1) # (M,)
184
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
185
+
186
+ emb_sin = np.sin(out) # (M, D/2)
187
+ emb_cos = np.cos(out) # (M, D/2)
188
+
189
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
190
+ return emb
191
+
192
+
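In closed form, position `m` is embedded as `sin(m / 10000**(2i/D))` for the first `D/2` dimensions and the matching cosines for the rest. A quick check of that identity (assuming `get_1d_sincos_pos_embed_from_grid` from above is in scope):

```python
# Check the closed-form identity for the 1D sin/cos position embedding.
import numpy as np

embed_dim = 8
positions = np.array([0.0, 1.0, 2.0])
emb = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)

omega = 1.0 / 10000 ** (np.arange(embed_dim // 2) / (embed_dim / 2))
expected = np.concatenate(
    [np.sin(np.outer(positions, omega)), np.cos(np.outer(positions, omega))], axis=1
)
assert np.allclose(emb, expected)  # shape (3, 8): sines first, cosines second
```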
193
+ class ViTMAEEmbeddings(nn.Module):
194
+ """
195
+ Construct the CLS token, position and patch embeddings.
196
+
197
+ """
198
+
199
+ def __init__(self, config):
200
+ super().__init__()
201
+
202
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
203
+ self.patch_embeddings = ViTMAEPatchEmbeddings(config)
204
+ self.num_patches = self.patch_embeddings.num_patches
205
+ # fixed sin-cos embedding
206
+ self.position_embeddings = nn.Parameter(
207
+ torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False
208
+ )
209
+ self.patch_size = config.patch_size
210
+ self.config = config
211
+
212
+ def initialize_weights(self):
213
+ # initialize (and freeze) position embeddings by sin-cos embedding
214
+ pos_embed = get_2d_sincos_pos_embed(
215
+ self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True
216
+ )
217
+ self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
218
+
219
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
220
+ w = self.patch_embeddings.projection.weight.data
221
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
222
+
223
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
224
+ torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
225
+
226
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
227
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
228
+ """
229
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher resolution
230
+ images. It is also adapted to support torch.jit tracing.
231
+
232
+ Adapted from:
233
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
234
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
235
+ """
236
+
237
+ num_patches = embeddings.shape[1] - 1
238
+ num_positions = self.position_embeddings.shape[1] - 1
239
+
240
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
241
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
242
+ return self.position_embeddings
243
+
244
+ class_pos_embed = self.position_embeddings[:, :1]
245
+ patch_pos_embed = self.position_embeddings[:, 1:]
246
+
247
+ dim = embeddings.shape[-1]
248
+
249
+ new_height = height // self.patch_size
250
+ new_width = width // self.patch_size
251
+
252
+ sqrt_num_positions = torch_int(num_positions**0.5)
253
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
254
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
255
+
256
+ patch_pos_embed = nn.functional.interpolate(
257
+ patch_pos_embed,
258
+ size=(new_height, new_width),
259
+ mode="bicubic",
260
+ align_corners=False,
261
+ )
262
+
263
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
264
+
265
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
266
+
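The core of the method is a bicubic resize of the square patch-embedding grid in channels-first layout. A minimal sketch with illustrative sizes:

```python
# Sketch of the bicubic resize applied to the patch position-embedding grid.
import torch
from torch import nn

dim, old_grid, new_grid = 16, 14, 16  # e.g. 196 -> 256 patch positions
pos = torch.randn(1, old_grid * old_grid, dim)

grid = pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)  # NCHW
grid = nn.functional.interpolate(
    grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False
)
resized = grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
print(resized.shape)  # torch.Size([1, 256, 16])
```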
267
+ def random_masking(self, sequence, noise=None):
268
+ """
269
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
270
+ noise.
271
+
272
+ Args:
273
+ sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
274
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
276
+ Mainly used for testing purposes, to control randomness and ensure reproducibility.
276
+ """
277
+ batch_size, seq_length, dim = sequence.shape
278
+ len_keep = int(seq_length * (1 - self.config.mask_ratio))
279
+
280
+ if noise is None:
281
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
282
+
283
+ # sort noise for each sample
284
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
285
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
286
+
287
+ # keep the first subset
288
+ ids_keep = ids_shuffle[:, :len_keep]
289
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
290
+
291
+ # generate the binary mask: 0 is keep, 1 is remove
292
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
293
+ mask[:, :len_keep] = 0
294
+ # unshuffle to get the binary mask
295
+ mask = torch.gather(mask, dim=1, index=ids_restore)
296
+
297
+ return sequence_unmasked, mask, ids_restore
298
+
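A standalone sketch of the double-argsort masking: `ids_shuffle` is a random permutation per sample and `ids_restore` is its inverse, so keeping the first `len_keep` shuffled indices yields a uniformly random set of visible patches:

```python
# Standalone sketch of the argsort-based masking in random_masking.
import torch

batch, seq_len, dim, mask_ratio = 2, 8, 4, 0.75
x = torch.randn(batch, seq_len, dim)
len_keep = int(seq_len * (1 - mask_ratio))  # 2 visible patches per sample

noise = torch.rand(batch, seq_len)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse

ids_keep = ids_shuffle[:, :len_keep]
visible = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, dim))

mask = torch.ones(batch, seq_len)
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)        # 1 = masked, in original order
print(visible.shape, mask.sum(dim=1))            # [2, 2, 4], tensor([6., 6.])
```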
299
+ def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
300
+ batch_size, num_channels, height, width = pixel_values.shape
301
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
302
+ if interpolate_pos_encoding:
303
+ position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
304
+ else:
305
+ position_embeddings = self.position_embeddings
306
+
307
+ # add position embeddings w/o cls token
308
+ embeddings = embeddings + position_embeddings[:, 1:, :]
309
+
310
+ # masking: length -> length * config.mask_ratio
311
+ embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
312
+
313
+ # append cls token
314
+ cls_token = self.cls_token + position_embeddings[:, :1, :]
315
+ cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
316
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
317
+
318
+ return embeddings, mask, ids_restore
319
+
320
+
321
+ class ViTMAEPatchEmbeddings(nn.Module):
322
+ """
323
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
324
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
325
+ Transformer.
326
+ """
327
+
328
+ def __init__(self, config):
329
+ super().__init__()
330
+ image_size, patch_size = config.image_size, config.patch_size
331
+ num_channels, hidden_size = config.num_channels, config.hidden_size
332
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
333
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
334
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
335
+ self.image_size = image_size
336
+ self.patch_size = patch_size
337
+ self.num_channels = num_channels
338
+ self.num_patches = num_patches
339
+
340
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
341
+
342
+ def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
343
+ batch_size, num_channels, height, width = pixel_values.shape
344
+ if num_channels != self.num_channels:
345
+ raise ValueError(
346
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
347
+ )
348
+
349
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
350
+ raise ValueError(
351
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
352
+ )
353
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
354
+ return x
355
+
356
+
357
+ # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
358
+ def eager_attention_forward(
359
+ module: nn.Module,
360
+ query: torch.Tensor,
361
+ key: torch.Tensor,
362
+ value: torch.Tensor,
363
+ attention_mask: Optional[torch.Tensor],
364
+ scaling: float,
365
+ dropout: float = 0.0,
366
+ **kwargs,
367
+ ):
368
+ # Take the dot product between "query" and "key" to get the raw attention scores.
369
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
370
+
371
+ # Normalize the attention scores to probabilities.
372
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
373
+
374
+ # This is actually dropping out entire tokens to attend to, which might
375
+ # seem a bit unusual, but is taken from the original Transformer paper.
376
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
377
+
378
+ # Mask heads if we want to
379
+ if attention_mask is not None:
380
+ attn_weights = attn_weights * attention_mask
381
+
382
+ attn_output = torch.matmul(attn_weights, value)
383
+ attn_output = attn_output.transpose(1, 2).contiguous()
384
+
385
+ return attn_output, attn_weights
386
+
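With no dropout and no head mask, this eager path reduces to plain scaled dot-product attention. The sketch below checks it against PyTorch's fused kernel, comparing in the `[batch, heads, seq, head_dim]` layout (i.e. before the final transpose in the function above); shapes are illustrative:

```python
# Sketch: the eager attention math agrees with PyTorch's fused SDPA kernel
# when there is no dropout and no head mask.
import torch
from torch import nn

batch, heads, seq, head_dim = 2, 4, 8, 16
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))
scaling = head_dim**-0.5

weights = nn.functional.softmax(torch.matmul(q, k.transpose(-1, -2)) * scaling, dim=-1)
eager = torch.matmul(weights, v)

fused = nn.functional.scaled_dot_product_attention(q, k, v)  # default scale = head_dim**-0.5
torch.testing.assert_close(eager, fused, rtol=1e-4, atol=1e-4)
```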
387
+
388
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMAE
389
+ class ViTMAESelfAttention(nn.Module):
390
+ def __init__(self, config: ViTMAEConfig) -> None:
391
+ super().__init__()
392
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
393
+ raise ValueError(
394
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
395
+ f"heads {config.num_attention_heads}."
396
+ )
397
+
398
+ self.config = config
399
+ self.num_attention_heads = config.num_attention_heads
400
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
401
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
402
+ self.dropout_prob = config.attention_probs_dropout_prob
403
+ self.scaling = self.attention_head_size**-0.5
404
+ self.is_causal = False
405
+
406
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
407
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
408
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
409
+
410
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
411
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
412
+ x = x.view(new_x_shape)
413
+ return x.permute(0, 2, 1, 3)
414
+
415
+ def forward(
416
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
417
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
418
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
419
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
420
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
421
+
422
+ attention_interface: Callable = eager_attention_forward
423
+ if self.config._attn_implementation != "eager":
424
+ if self.config._attn_implementation == "sdpa" and output_attentions:
425
+ logger.warning_once(
426
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
427
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
428
+ )
429
+ else:
430
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
431
+
432
+ context_layer, attention_probs = attention_interface(
433
+ self,
434
+ query_layer,
435
+ key_layer,
436
+ value_layer,
437
+ head_mask,
438
+ is_causal=self.is_causal,
439
+ scaling=self.scaling,
440
+ dropout=0.0 if not self.training else self.dropout_prob,
441
+ )
442
+
443
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
444
+ context_layer = context_layer.reshape(new_context_layer_shape)
445
+
446
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
447
+
448
+ return outputs
449
+
450
+
451
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE
452
+ class ViTMAESelfOutput(nn.Module):
453
+ """
454
+ The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the
455
+ layernorm applied before each block.
456
+ """
457
+
458
+ def __init__(self, config: ViTMAEConfig) -> None:
459
+ super().__init__()
460
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
461
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
462
+
463
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
464
+ hidden_states = self.dense(hidden_states)
465
+ hidden_states = self.dropout(hidden_states)
466
+
467
+ return hidden_states
468
+
469
+
470
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE
471
+ class ViTMAEAttention(nn.Module):
472
+ def __init__(self, config: ViTMAEConfig) -> None:
473
+ super().__init__()
474
+ self.attention = ViTMAESelfAttention(config)
475
+ self.output = ViTMAESelfOutput(config)
476
+ self.pruned_heads = set()
477
+
478
+ def prune_heads(self, heads: Set[int]) -> None:
479
+ if len(heads) == 0:
480
+ return
481
+ heads, index = find_pruneable_heads_and_indices(
482
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
483
+ )
484
+
485
+ # Prune linear layers
486
+ self.attention.query = prune_linear_layer(self.attention.query, index)
487
+ self.attention.key = prune_linear_layer(self.attention.key, index)
488
+ self.attention.value = prune_linear_layer(self.attention.value, index)
489
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
490
+
491
+ # Update hyper params and store pruned heads
492
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
493
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
494
+ self.pruned_heads = self.pruned_heads.union(heads)
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states: torch.Tensor,
499
+ head_mask: Optional[torch.Tensor] = None,
500
+ output_attentions: bool = False,
501
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
502
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
503
+
504
+ attention_output = self.output(self_outputs[0], hidden_states)
505
+
506
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
507
+ return outputs
508
+
509
+
510
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMAE
511
+ class ViTMAEIntermediate(nn.Module):
512
+ def __init__(self, config: ViTMAEConfig) -> None:
513
+ super().__init__()
514
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
515
+ if isinstance(config.hidden_act, str):
516
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
517
+ else:
518
+ self.intermediate_act_fn = config.hidden_act
519
+
520
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
521
+ hidden_states = self.dense(hidden_states)
522
+ hidden_states = self.intermediate_act_fn(hidden_states)
523
+
524
+ return hidden_states
525
+
526
+
527
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMAE
528
+ class ViTMAEOutput(nn.Module):
529
+ def __init__(self, config: ViTMAEConfig) -> None:
530
+ super().__init__()
531
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
532
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
533
+
534
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
535
+ hidden_states = self.dense(hidden_states)
536
+ hidden_states = self.dropout(hidden_states)
537
+
538
+ hidden_states = hidden_states + input_tensor
539
+
540
+ return hidden_states
541
+
542
+
543
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE
544
+ class ViTMAELayer(nn.Module):
545
+ """This corresponds to the Block class in the timm implementation."""
546
+
547
+ def __init__(self, config: ViTMAEConfig) -> None:
548
+ super().__init__()
549
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
550
+ self.seq_len_dim = 1
551
+ self.attention = ViTMAEAttention(config)
552
+ self.intermediate = ViTMAEIntermediate(config)
553
+ self.output = ViTMAEOutput(config)
554
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
555
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ head_mask: Optional[torch.Tensor] = None,
561
+ output_attentions: bool = False,
562
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
563
+ self_attention_outputs = self.attention(
564
+ self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention
565
+ head_mask,
566
+ output_attentions=output_attentions,
567
+ )
568
+ attention_output = self_attention_outputs[0]
569
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
570
+
571
+ # first residual connection
572
+ hidden_states = attention_output + hidden_states
573
+
574
+ # in ViTMAE, layernorm is also applied after self-attention
575
+ layer_output = self.layernorm_after(hidden_states)
576
+ layer_output = self.intermediate(layer_output)
577
+
578
+ # second residual connection is done here
579
+ layer_output = self.output(layer_output, hidden_states)
580
+
581
+ outputs = (layer_output,) + outputs
582
+
583
+ return outputs
584
+
585
+
586
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE
587
+ class ViTMAEEncoder(nn.Module):
588
+ def __init__(self, config: ViTMAEConfig) -> None:
589
+ super().__init__()
590
+ self.config = config
591
+ self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)])
592
+ self.gradient_checkpointing = False
593
+
594
+ def forward(
595
+ self,
596
+ hidden_states: torch.Tensor,
597
+ head_mask: Optional[torch.Tensor] = None,
598
+ output_attentions: bool = False,
599
+ output_hidden_states: bool = False,
600
+ return_dict: bool = True,
601
+ ) -> Union[tuple, BaseModelOutput]:
602
+ all_hidden_states = () if output_hidden_states else None
603
+ all_self_attentions = () if output_attentions else None
604
+
605
+ for i, layer_module in enumerate(self.layer):
606
+ if output_hidden_states:
607
+ all_hidden_states = all_hidden_states + (hidden_states,)
608
+
609
+ layer_head_mask = head_mask[i] if head_mask is not None else None
610
+
611
+ if self.gradient_checkpointing and self.training:
612
+ layer_outputs = self._gradient_checkpointing_func(
613
+ layer_module.__call__,
614
+ hidden_states,
615
+ layer_head_mask,
616
+ output_attentions,
617
+ )
618
+ else:
619
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
620
+
621
+ hidden_states = layer_outputs[0]
622
+
623
+ if output_attentions:
624
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
625
+
626
+ if output_hidden_states:
627
+ all_hidden_states = all_hidden_states + (hidden_states,)
628
+
629
+ if not return_dict:
630
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
631
+ return BaseModelOutput(
632
+ last_hidden_state=hidden_states,
633
+ hidden_states=all_hidden_states,
634
+ attentions=all_self_attentions,
635
+ )
636
+
637
+
638
+ class ViTMAEPreTrainedModel(PreTrainedModel):
639
+ """
640
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
641
+ models.
642
+ """
643
+
644
+ config_class = ViTMAEConfig
645
+ base_model_prefix = "vit"
646
+ main_input_name = "pixel_values"
647
+ supports_gradient_checkpointing = True
648
+ _supports_sdpa = True
649
+ _supports_flash_attn_2 = True
650
+
651
+ def _init_weights(self, module):
652
+ """Initialize the weights"""
653
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
654
+ # Slightly different from the TF version which uses truncated_normal for initialization
655
+ # cf https://github.com/pytorch/pytorch/pull/5617
656
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
657
+ if module.bias is not None:
658
+ module.bias.data.zero_()
659
+ elif isinstance(module, nn.LayerNorm):
660
+ module.bias.data.zero_()
661
+ module.weight.data.fill_(1.0)
662
+ elif isinstance(module, ViTMAEEmbeddings):
663
+ module.initialize_weights()
664
+ elif isinstance(module, ViTMAEDecoder):
665
+ module.mask_token.data.zero_()
666
+ module.decoder_pos_embed.data.zero_()
667
+
668
+
669
+ VIT_MAE_START_DOCSTRING = r"""
670
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
671
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
672
+ behavior.
673
+
674
+ Parameters:
675
+ config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
676
+ Initializing with a config file does not load the weights associated with the model, only the
677
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
678
+ """
679
+
680
+ VIT_MAE_INPUTS_DOCSTRING = r"""
681
+ Args:
682
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
683
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
684
+ for details.
685
+
686
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
687
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
688
+
689
+ - 1 indicates the head is **not masked**,
690
+ - 0 indicates the head is **masked**.
691
+
692
+ output_attentions (`bool`, *optional*):
693
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
694
+ tensors for more detail.
695
+ output_hidden_states (`bool`, *optional*):
696
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
697
+ more detail.
698
+ return_dict (`bool`, *optional*):
699
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
700
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
701
+ Whether to interpolate the pre-trained position encodings. This is mainly useful for running the model on
702
+ higher resolution images.
703
+ """
704
+
705
+
706
+ @add_start_docstrings(
707
+ "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
708
+ VIT_MAE_START_DOCSTRING,
709
+ )
710
+ class ViTMAEModel(ViTMAEPreTrainedModel):
711
+ def __init__(self, config):
712
+ super().__init__(config)
713
+ self.config = config
714
+
715
+ self.embeddings = ViTMAEEmbeddings(config)
716
+ self.encoder = ViTMAEEncoder(config)
717
+
718
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
719
+
720
+ # Initialize weights and apply final processing
721
+ self.post_init()
722
+
723
+ def get_input_embeddings(self):
724
+ return self.embeddings.patch_embeddings
725
+
726
+ def _prune_heads(self, heads_to_prune):
727
+ """
728
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
729
+ base class `PreTrainedModel` for details.
730
+ """
731
+ for layer, heads in heads_to_prune.items():
732
+ self.encoder.layer[layer].attention.prune_heads(heads)
733
+
734
+ @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
735
+ @replace_return_docstrings(output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
736
+ def forward(
737
+ self,
738
+ pixel_values: Optional[torch.FloatTensor] = None,
739
+ noise: Optional[torch.FloatTensor] = None,
740
+ head_mask: Optional[torch.FloatTensor] = None,
741
+ output_attentions: Optional[bool] = None,
742
+ output_hidden_states: Optional[bool] = None,
743
+ return_dict: Optional[bool] = None,
744
+ interpolate_pos_encoding: bool = False,
745
+ ) -> Union[Tuple, ViTMAEModelOutput]:
746
+ r"""
747
+ Returns:
748
+
749
+ Examples:
750
+
751
+ ```python
752
+ >>> from transformers import AutoImageProcessor, ViTMAEModel
753
+ >>> from PIL import Image
754
+ >>> import requests
755
+
756
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
757
+ >>> image = Image.open(requests.get(url, stream=True).raw)
758
+
759
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
760
+ >>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
761
+
762
+ >>> inputs = image_processor(images=image, return_tensors="pt")
763
+ >>> outputs = model(**inputs)
764
+ >>> last_hidden_states = outputs.last_hidden_state
765
+ ```"""
766
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
767
+ output_hidden_states = (
768
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
769
+ )
770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
771
+
772
+ if pixel_values is None:
773
+ raise ValueError("You have to specify pixel_values")
774
+
775
+ # Prepare head mask if needed
776
+ # 1.0 in head_mask indicate we keep the head
777
+ # attention_probs has shape bsz x n_heads x N x N
778
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
779
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
780
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
781
+
782
+ embedding_output, mask, ids_restore = self.embeddings(
783
+ pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
784
+ )
785
+
786
+ encoder_outputs = self.encoder(
787
+ embedding_output,
788
+ head_mask=head_mask,
789
+ output_attentions=output_attentions,
790
+ output_hidden_states=output_hidden_states,
791
+ return_dict=return_dict,
792
+ )
793
+ sequence_output = encoder_outputs[0]
794
+ sequence_output = self.layernorm(sequence_output)
795
+
796
+ if not return_dict:
797
+ return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
798
+
799
+ return ViTMAEModelOutput(
800
+ last_hidden_state=sequence_output,
801
+ mask=mask,
802
+ ids_restore=ids_restore,
803
+ hidden_states=encoder_outputs.hidden_states,
804
+ attentions=encoder_outputs.attentions,
805
+ )
806
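
The `noise` argument above deserves a brief illustration: ViTMAE derives its random patch masking from a per-patch noise tensor, so supplying a fixed `noise` makes the masking reproducible across calls. A minimal sketch, assuming the `(batch_size, num_patches)` shape consumed by the embeddings' random-masking step:

```python
import torch
from transformers import ViTMAEModel

model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")
num_patches = model.embeddings.num_patches  # 196 for 224x224 inputs with 16x16 patches

pixel_values = torch.randn(1, 3, 224, 224)
noise = torch.rand(1, num_patches)  # fixed noise -> fixed masking (shape is a sketch assumption)

out1 = model(pixel_values, noise=noise)
out2 = model(pixel_values, noise=noise)
assert torch.equal(out1.mask, out2.mask)  # the same patches are masked on both calls
```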
+
807
+
808
+ class ViTMAEDecoder(nn.Module):
809
+ def __init__(self, config, num_patches):
810
+ super().__init__()
811
+ self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)
812
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
813
+ self.decoder_pos_embed = nn.Parameter(
814
+ torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False
815
+ ) # fixed sin-cos embedding
816
+
817
+ decoder_config = deepcopy(config)
818
+ decoder_config.hidden_size = config.decoder_hidden_size
819
+ decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
820
+ decoder_config.num_attention_heads = config.decoder_num_attention_heads
821
+ decoder_config.intermediate_size = config.decoder_intermediate_size
822
+ self.decoder_layers = nn.ModuleList(
823
+ [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
824
+ )
825
+
826
+ self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
827
+ self.decoder_pred = nn.Linear(
828
+ config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
829
+ ) # decoder to patch
830
+ self.gradient_checkpointing = False
831
+ self.config = config
832
+ self.initialize_weights(num_patches)
833
+
834
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
835
+ """
836
+ This method is a modified version of the encoder's interpolation function, adapted to the decoder: it allows
837
+ interpolating the pre-trained decoder position encodings so that the model can be used on higher resolution
838
+ images.
839
+
840
+ Adapted from:
841
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
842
+ """
843
+
844
+ # -1 removes the class dimension since we later append it without interpolation
845
+ embeddings_positions = embeddings.shape[1] - 1
846
+
847
+ # Separation of class token and patch tokens
848
+ class_pos_embed = self.decoder_pos_embed[:, :1]
849
+ patch_pos_embed = self.decoder_pos_embed[:, 1:]
850
+
851
+ # To retain the final 3d tensor with the required dimensions
852
+ dim = self.decoder_pos_embed.shape[-1]
853
+
854
+ # Increasing a dimension to enable bicubic interpolation
855
+ patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
856
+
857
+ # permute to bring the dimension to be interpolated, to the last
858
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
859
+
860
+ # Interpolate the decoder position embeddings to match the number of patch positions in the embeddings,
861
+ # keeping the second-to-last dimension constant
862
+ patch_pos_embed = nn.functional.interpolate(
863
+ patch_pos_embed,
864
+ size=(patch_pos_embed.shape[-2], embeddings_positions),
865
+ mode="bicubic",
866
+ align_corners=False,
867
+ )
868
+
869
+ # Converting back to the original shape
870
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
871
+ # Adding the class token back
872
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
873
+
874
+ def initialize_weights(self, num_patches):
875
+ # initialize (and freeze) position embeddings by sin-cos embedding
876
+ decoder_pos_embed = get_2d_sincos_pos_embed(
877
+ self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
878
+ )
879
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
880
+
881
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
882
+ torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)
883
+
884
+ def forward(
885
+ self,
886
+ hidden_states,
887
+ ids_restore,
888
+ output_attentions=False,
889
+ output_hidden_states=False,
890
+ return_dict=True,
891
+ interpolate_pos_encoding: bool = False,
892
+ ):
893
+ # embed tokens
894
+ x = self.decoder_embed(hidden_states)
895
+
896
+ # append mask tokens to sequence
897
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
898
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
899
+ # unshuffle
900
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
901
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
902
+ # add pos embed
903
+ if interpolate_pos_encoding:
904
+ decoder_pos_embed = self.interpolate_pos_encoding(x)
905
+ else:
906
+ decoder_pos_embed = self.decoder_pos_embed
907
+ hidden_states = x + decoder_pos_embed
908
+
909
+ # apply Transformer layers (blocks)
910
+ all_hidden_states = () if output_hidden_states else None
911
+ all_self_attentions = () if output_attentions else None
912
+ for i, layer_module in enumerate(self.decoder_layers):
913
+ if output_hidden_states:
914
+ all_hidden_states = all_hidden_states + (hidden_states,)
915
+
916
+ if self.gradient_checkpointing and self.training:
917
+ layer_outputs = self._gradient_checkpointing_func(
918
+ layer_module.__call__,
919
+ hidden_states,
920
+ None,
921
+ output_attentions,
922
+ )
923
+ else:
924
+ layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
925
+
926
+ hidden_states = layer_outputs[0]
927
+
928
+ if output_attentions:
929
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
930
+
931
+ if output_hidden_states:
932
+ all_hidden_states = all_hidden_states + (hidden_states,)
933
+
934
+ hidden_states = self.decoder_norm(hidden_states)
935
+
936
+ # predictor projection
937
+ logits = self.decoder_pred(hidden_states)
938
+
939
+ # remove cls token
940
+ logits = logits[:, 1:, :]
941
+
942
+ if not return_dict:
943
+ return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
944
+ return ViTMAEDecoderOutput(
945
+ logits=logits,
946
+ hidden_states=all_hidden_states,
947
+ attentions=all_self_attentions,
948
+ )
949
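
The "unshuffle" step in the decoder forward is easy to misread, so here is a self-contained toy sketch (synthetic tensors, no CLS token) of how `torch.gather` with `ids_restore` restores the original patch order once kept tokens and mask tokens have been concatenated:

```python
import torch

# toy setup: 4 patches, of which 2 were kept by the encoder
ids_shuffle = torch.tensor([[2, 0, 3, 1]])       # permutation used when masking
ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation: [[1, 3, 0, 2]]

kept = torch.tensor([[[20.0], [0.0]]])           # encoder outputs for patches 2 and 0
mask_tokens = -torch.ones(1, 2, 1)               # placeholders for the 2 masked patches

x_ = torch.cat([kept, mask_tokens], dim=1)       # [20, 0, -1, -1] in shuffled order
unshuffled = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1))
print(unshuffled.squeeze(-1))                    # tensor([[ 0., -1., 20., -1.]])
```

Patches 0 and 2 get their encoder outputs back at their original positions, while the masked patches 1 and 3 receive the placeholder token, exactly as in the forward pass above.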
+
950
+
951
+ @add_start_docstrings(
952
+ """The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.
953
+
954
+ <Tip>
955
+
956
+ Note that we provide a script to pre-train this model on custom data in our [examples
957
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
958
+
959
+ </Tip>
960
+
961
+ """,
962
+ VIT_MAE_START_DOCSTRING,
963
+ )
964
+ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
965
+ def __init__(self, config):
966
+ super().__init__(config)
967
+ self.config = config
968
+
969
+ self.vit = ViTMAEModel(config)
970
+ self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)
971
+
972
+ # Initialize weights and apply final processing
973
+ self.post_init()
974
+
975
+ def get_input_embeddings(self):
976
+ return self.vit.embeddings.patch_embeddings
977
+
978
+ def _prune_heads(self, heads_to_prune):
979
+ """
980
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
981
+ base class `PreTrainedModel` for details.
982
+ """
983
+ for layer, heads in heads_to_prune.items():
984
+ self.vit.encoder.layer[layer].attention.prune_heads(heads)
985
+
986
+ def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
987
+ """
988
+ Args:
989
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
990
+ Pixel values.
991
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
992
+ Interpolation flag passed through from the forward pass.
993
+
994
+ Returns:
995
+ `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
996
+ Patchified pixel values.
997
+ """
998
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
999
+ # sanity checks
1000
+ if not interpolate_pos_encoding and (
1001
+ pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
1002
+ ):
1003
+ raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
1004
+ if pixel_values.shape[1] != num_channels:
1005
+ raise ValueError(
1006
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
1007
+ )
1008
+
1009
+ # patchify
1010
+ batch_size = pixel_values.shape[0]
1011
+ num_patches_h = pixel_values.shape[2] // patch_size
1012
+ num_patches_w = pixel_values.shape[3] // patch_size
1013
+ patchified_pixel_values = pixel_values.reshape(
1014
+ batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
1015
+ )
1016
+ patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
1017
+ patchified_pixel_values = patchified_pixel_values.reshape(
1018
+ batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
1019
+ )
1020
+ return patchified_pixel_values
1021
+
1022
+ def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
1023
+ """
1024
+ Args:
1025
+ patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
1026
+ Patchified pixel values.
1027
+ original_image_size (`Tuple[int, int]`, *optional*):
1028
+ Original image size.
1029
+
1030
+ Returns:
1031
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
1032
+ Pixel values.
1033
+ """
1034
+ patch_size, num_channels = self.config.patch_size, self.config.num_channels
1035
+ original_image_size = (
1036
+ original_image_size
1037
+ if original_image_size is not None
1038
+ else (self.config.image_size, self.config.image_size)
1039
+ )
1040
+ original_height, original_width = original_image_size
1041
+ num_patches_h = original_height // patch_size
1042
+ num_patches_w = original_width // patch_size
1043
+ # sanity check
1044
+ if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
1045
+ raise ValueError(
1046
+ f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
1047
+ )
1048
+
1049
+ # unpatchify
1050
+ batch_size = patchified_pixel_values.shape[0]
1051
+ patchified_pixel_values = patchified_pixel_values.reshape(
1052
+ batch_size,
1053
+ num_patches_h,
1054
+ num_patches_w,
1055
+ patch_size,
1056
+ patch_size,
1057
+ num_channels,
1058
+ )
1059
+ patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
1060
+ pixel_values = patchified_pixel_values.reshape(
1061
+ batch_size,
1062
+ num_channels,
1063
+ num_patches_h * patch_size,
1064
+ num_patches_w * patch_size,
1065
+ )
1066
+ return pixel_values
1067
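
A quick round trip makes the shapes that `patchify` and `unpatchify` exchange concrete (a sketch with the default 224x224 images, 16x16 patches and 3 channels; a randomly initialized model is enough since both methods are pure reshapes):

```python
import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

model = ViTMAEForPreTraining(ViTMAEConfig())

pixel_values = torch.randn(2, 3, 224, 224)
patches = model.patchify(pixel_values)
print(patches.shape)  # torch.Size([2, 196, 768]): (batch, 14*14 patches, 16*16*3 values)

reconstructed = model.unpatchify(patches)
assert torch.equal(reconstructed, pixel_values)  # the two methods are exact inverses
```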
+
1068
+ def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
1069
+ """
1070
+ Args:
1071
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1072
+ Pixel values.
1073
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
1074
+ Predicted pixel values.
1075
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1076
+ Tensor indicating which patches are masked (1) and which are not (0).
1077
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
1078
+ Interpolation flag passed through from the forward pass.
1079
+
1080
+ Returns:
1081
+ `torch.FloatTensor`: Pixel reconstruction loss.
1082
+ """
1083
+ target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1084
+ if self.config.norm_pix_loss:
1085
+ mean = target.mean(dim=-1, keepdim=True)
1086
+ var = target.var(dim=-1, keepdim=True)
1087
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
1088
+
1089
+ loss = (pred - target) ** 2
1090
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
1091
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
1092
+ return loss
1093
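
The loss above is a per-patch mean-squared error averaged only over masked patches; with `norm_pix_loss` enabled, the target patch is standardized first. A small sketch of the masked averaging on plain tensors:

```python
import torch

pred = torch.zeros(1, 4, 768)                # decoder predictions for 4 patches
target = torch.ones(1, 4, 768)               # patchified ground-truth pixels
mask = torch.tensor([[1.0, 0.0, 1.0, 0.0]])  # 1 = masked (reconstructed), 0 = visible

per_patch = ((pred - target) ** 2).mean(dim=-1)  # per-patch MSE, shape (1, 4)
loss = (per_patch * mask).sum() / mask.sum()     # averaged over masked patches only
print(loss)                                      # tensor(1.) -- visible patches contribute nothing
```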
+
1094
+ @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
1095
+ @replace_return_docstrings(output_type=ViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1096
+ def forward(
1097
+ self,
1098
+ pixel_values: Optional[torch.FloatTensor] = None,
1099
+ noise: Optional[torch.FloatTensor] = None,
1100
+ head_mask: Optional[torch.FloatTensor] = None,
1101
+ output_attentions: Optional[bool] = None,
1102
+ output_hidden_states: Optional[bool] = None,
1103
+ return_dict: Optional[bool] = None,
1104
+ interpolate_pos_encoding: bool = False,
1105
+ ) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
1106
+ r"""
1107
+ Returns:
1108
+
1109
+ Examples:
1110
+
1111
+ ```python
1112
+ >>> from transformers import AutoImageProcessor, ViTMAEForPreTraining
1113
+ >>> from PIL import Image
1114
+ >>> import requests
1115
+
1116
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1117
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1118
+
1119
+ >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
1120
+ >>> model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
1121
+
1122
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1123
+ >>> outputs = model(**inputs)
1124
+ >>> loss = outputs.loss
1125
+ >>> mask = outputs.mask
1126
+ >>> ids_restore = outputs.ids_restore
1127
+ ```"""
1128
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1129
+
1130
+ outputs = self.vit(
1131
+ pixel_values,
1132
+ noise=noise,
1133
+ head_mask=head_mask,
1134
+ output_attentions=output_attentions,
1135
+ output_hidden_states=output_hidden_states,
1136
+ return_dict=return_dict,
1137
+ interpolate_pos_encoding=interpolate_pos_encoding,
1138
+ )
1139
+
1140
+ latent = outputs.last_hidden_state
1141
+ ids_restore = outputs.ids_restore
1142
+ mask = outputs.mask
1143
+
1144
+ decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
1145
+ logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
1146
+
1147
+ loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
1148
+
1149
+ if not return_dict:
1150
+ output = (logits, mask, ids_restore) + outputs[2:]
1151
+ return ((loss,) + output) if loss is not None else output
1152
+
1153
+ return ViTMAEForPreTrainingOutput(
1154
+ loss=loss,
1155
+ logits=logits,
1156
+ mask=mask,
1157
+ ids_restore=ids_restore,
1158
+ hidden_states=outputs.hidden_states,
1159
+ attentions=outputs.attentions,
1160
+ )
1161
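
Because the logits come back in patchified form, `unpatchify` is the natural way to inspect a reconstruction. A hedged sketch continuing the docstring example above (pasting visible patches over the prediction is a common visualization choice, not an official recipe):

```python
# continuing from the docstring example: `inputs` and `outputs` as above
reconstruction = model.unpatchify(outputs.logits)  # (batch, 3, 224, 224)

mask = outputs.mask.unsqueeze(-1)                  # (batch, num_patches, 1); 1 = masked
visible = model.patchify(inputs["pixel_values"]) * (1 - mask)
pasted = model.unpatchify(visible + outputs.logits * mask)
print(pasted.shape)                                # torch.Size([1, 3, 224, 224])
```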
+
1162
+
1163
+ __all__ = ["ViTMAEForPreTraining", "ViTMAELayer", "ViTMAEModel", "ViTMAEPreTrainedModel"]
docs/transformers/build/lib/transformers/models/vit_msn/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_vit_msn import *
22
+ from .modeling_vit_msn import *
23
+ else:
24
+ import sys
25
+
26
+ _file = globals()["__file__"]
27
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
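
As a usage note, the `_LazyModule` indirection above defers importing the configuration and modeling submodules until one of their attributes is first accessed; the public names still import as usual:

```python
# modeling_vit_msn is only actually loaded when ViTMSNModel is first touched
from transformers.models.vit_msn import ViTMSNConfig, ViTMSNModel

model = ViTMSNModel(ViTMSNConfig())
```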
docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py ADDED
@@ -0,0 +1,115 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ViT MSN model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class ViTMSNConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate a ViT
27
+ MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
28
+ the defaults will yield a similar configuration to that of the ViT
29
+ [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ hidden_size (`int`, *optional*, defaults to 768):
37
+ Dimensionality of the encoder layers and the pooler layer.
38
+ num_hidden_layers (`int`, *optional*, defaults to 12):
39
+ Number of hidden layers in the Transformer encoder.
40
+ num_attention_heads (`int`, *optional*, defaults to 12):
41
+ Number of attention heads for each attention layer in the Transformer encoder.
42
+ intermediate_size (`int`, *optional*, defaults to 3072):
43
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
47
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
48
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
49
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
50
+ The dropout ratio for the attention probabilities.
51
+ initializer_range (`float`, *optional*, defaults to 0.02):
52
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
53
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
54
+ The epsilon used by the layer normalization layers.
55
+ image_size (`int`, *optional*, defaults to 224):
56
+ The size (resolution) of each image.
57
+ patch_size (`int`, *optional*, defaults to 16):
58
+ The size (resolution) of each patch.
59
+ num_channels (`int`, *optional*, defaults to 3):
60
+ The number of input channels.
61
+ qkv_bias (`bool`, *optional*, defaults to `True`):
62
+ Whether to add a bias to the queries, keys and values.
63
+
64
+ Example:
65
+
66
+ ```python
67
+ >>> from transformers import ViTMSNModel, ViTMSNConfig
68
+
69
+ >>> # Initializing a ViT MSN vit-msn-base style configuration
70
+ >>> configuration = ViTConfig()
71
+
72
+ >>> # Initializing a model from the vit-msn-base style configuration
73
+ >>> model = ViTMSNModel(configuration)
74
+
75
+ >>> # Accessing the model configuration
76
+ >>> configuration = model.config
77
+ ```"""
78
+
79
+ model_type = "vit_msn"
80
+
81
+ def __init__(
82
+ self,
83
+ hidden_size=768,
84
+ num_hidden_layers=12,
85
+ num_attention_heads=12,
86
+ intermediate_size=3072,
87
+ hidden_act="gelu",
88
+ hidden_dropout_prob=0.0,
89
+ attention_probs_dropout_prob=0.0,
90
+ initializer_range=0.02,
91
+ layer_norm_eps=1e-06,
92
+ image_size=224,
93
+ patch_size=16,
94
+ num_channels=3,
95
+ qkv_bias=True,
96
+ **kwargs,
97
+ ):
98
+ super().__init__(**kwargs)
99
+
100
+ self.hidden_size = hidden_size
101
+ self.num_hidden_layers = num_hidden_layers
102
+ self.num_attention_heads = num_attention_heads
103
+ self.intermediate_size = intermediate_size
104
+ self.hidden_act = hidden_act
105
+ self.hidden_dropout_prob = hidden_dropout_prob
106
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
107
+ self.initializer_range = initializer_range
108
+ self.layer_norm_eps = layer_norm_eps
109
+ self.image_size = image_size
110
+ self.patch_size = patch_size
111
+ self.num_channels = num_channels
112
+ self.qkv_bias = qkv_bias
113
+
114
+
115
+ __all__ = ["ViTMSNConfig"]
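
Overriding the defaults reproduces the other MSN variants; for example, the ViT-S/16 settings that the conversion script below applies for "s16" checkpoints (hidden size 384, intermediate size 1536, 6 attention heads):

```python
from transformers import ViTMSNConfig, ViTMSNModel

# ViT-S/16-style configuration, matching the "s16" branch of the conversion script
small_config = ViTMSNConfig(hidden_size=384, intermediate_size=1536, num_attention_heads=6)
model = ViTMSNModel(small_config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```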
docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py ADDED
@@ -0,0 +1,245 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
16
+
17
+ import argparse
18
+ import json
19
+
20
+ import requests
21
+ import torch
22
+ from huggingface_hub import hf_hub_download
23
+ from PIL import Image
24
+
25
+ from transformers import ViTImageProcessor, ViTMSNConfig, ViTMSNModel
26
+ from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
27
+
28
+
29
+ torch.set_grad_enabled(False)
30
+
31
+
32
+ # here we list all keys to be renamed (original name on the left, our name on the right)
33
+ def create_rename_keys(config, base_model=False):
34
+ rename_keys = []
35
+ for i in range(config.num_hidden_layers):
36
+ # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
37
+ rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
38
+ rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
39
+ rename_keys.append(
40
+ (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")
41
+ )
42
+ rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
43
+ rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
44
+ rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
45
+ rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
46
+ rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
47
+ rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
48
+ rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))
49
+
50
+ # projection layer + position embeddings
51
+ rename_keys.extend(
52
+ [
53
+ ("module.cls_token", "vit.embeddings.cls_token"),
54
+ ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
55
+ ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
56
+ ("module.pos_embed", "vit.embeddings.position_embeddings"),
57
+ ]
58
+ )
59
+
60
+ if base_model:
61
+ # layernorm + pooler
62
+ rename_keys.extend(
63
+ [
64
+ ("module.norm.weight", "layernorm.weight"),
65
+ ("module.norm.bias", "layernorm.bias"),
66
+ ]
67
+ )
68
+
69
+ # if just the base model, we should remove "vit" from all keys that start with "vit"
70
+ rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
71
+ else:
72
+ # layernorm + classification head
73
+ rename_keys.extend(
74
+ [
75
+ ("norm.weight", "vit.layernorm.weight"),
76
+ ("norm.bias", "vit.layernorm.bias"),
77
+ ("head.weight", "classifier.weight"),
78
+ ("head.bias", "classifier.bias"),
79
+ ]
80
+ )
81
+
82
+ return rename_keys
83
+
84
+
85
+ # we split up the matrix of each encoder layer into queries, keys and values
86
+ def read_in_q_k_v(state_dict, config, base_model=False):
87
+ for i in range(config.num_hidden_layers):
88
+ if base_model:
89
+ prefix = ""
90
+ else:
91
+ prefix = "vit."
92
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
93
+ in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight")
94
+ in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias")
95
+ # next, add query, keys and values (in that order) to the state dict
96
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
97
+ : config.hidden_size, :
98
+ ]
99
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
100
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
101
+ config.hidden_size : config.hidden_size * 2, :
102
+ ]
103
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
104
+ config.hidden_size : config.hidden_size * 2
105
+ ]
106
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
107
+ -config.hidden_size :, :
108
+ ]
109
+ state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
110
+
111
+
112
+ def remove_classification_head_(state_dict):
113
+ ignore_keys = ["head.weight", "head.bias"]
114
+ for k in ignore_keys:
115
+ state_dict.pop(k, None)
116
+
117
+
118
+ def remove_projection_head(state_dict):
119
+ # projection head is used in the self-supervised pre-training in MSN,
120
+ # for downstream tasks it is not needed.
121
+ ignore_keys = [
122
+ "module.fc.fc1.weight",
123
+ "module.fc.fc1.bias",
124
+ "module.fc.bn1.weight",
125
+ "module.fc.bn1.bias",
126
+ "module.fc.bn1.running_mean",
127
+ "module.fc.bn1.running_var",
128
+ "module.fc.bn1.num_batches_tracked",
129
+ "module.fc.fc2.weight",
130
+ "module.fc.fc2.bias",
131
+ "module.fc.bn2.weight",
132
+ "module.fc.bn2.bias",
133
+ "module.fc.bn2.running_mean",
134
+ "module.fc.bn2.running_var",
135
+ "module.fc.bn2.num_batches_tracked",
136
+ "module.fc.fc3.weight",
137
+ "module.fc.fc3.bias",
138
+ ]
139
+ for k in ignore_keys:
140
+ state_dict.pop(k, None)
141
+
142
+
143
+ def rename_key(dct, old, new):
144
+ val = dct.pop(old)
145
+ dct[new] = val
146
+
147
+
148
+ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
149
+ config = ViTMSNConfig()
150
+ config.num_labels = 1000
151
+
152
+ repo_id = "datasets/huggingface/label-files"
153
+ filename = "imagenet-1k-id2label.json"
154
+ id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
155
+ id2label = {int(k): v for k, v in id2label.items()}
156
+ config.id2label = id2label
157
+ config.label2id = {v: k for k, v in id2label.items()}
158
+
159
+ if "s16" in checkpoint_url:
160
+ config.hidden_size = 384
161
+ config.intermediate_size = 1536
162
+ config.num_attention_heads = 6
163
+ elif "l16" in checkpoint_url:
164
+ config.hidden_size = 1024
165
+ config.intermediate_size = 4096
166
+ config.num_hidden_layers = 24
167
+ config.num_attention_heads = 16
168
+ config.hidden_dropout_prob = 0.1
169
+ elif "b4" in checkpoint_url:
170
+ config.patch_size = 4
171
+ elif "l7" in checkpoint_url:
172
+ config.patch_size = 7
173
+ config.hidden_size = 1024
174
+ config.intermediate_size = 4096
175
+ config.num_hidden_layers = 24
176
+ config.num_attention_heads = 16
177
+ config.hidden_dropout_prob = 0.1
178
+
179
+ model = ViTMSNModel(config)
180
+
181
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"]
182
+
183
+ image_processor = ViTImageProcessor(size=config.image_size)
184
+
185
+ remove_projection_head(state_dict)
186
+ rename_keys = create_rename_keys(config, base_model=True)
187
+
188
+ for src, dest in rename_keys:
189
+ rename_key(state_dict, src, dest)
190
+ read_in_q_k_v(state_dict, config, base_model=True)
191
+
192
+ model.load_state_dict(state_dict)
193
+ model.eval()
194
+
195
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
196
+
197
+ image = Image.open(requests.get(url, stream=True).raw)
198
+ image_processor = ViTImageProcessor(
199
+ size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD
200
+ )
201
+ inputs = image_processor(images=image, return_tensors="pt")
202
+
203
+ # forward pass
204
+ torch.manual_seed(2)
205
+ outputs = model(**inputs)
206
+ last_hidden_state = outputs.last_hidden_state
207
+
208
+ # The following Colab Notebook was used to generate these outputs:
209
+ # https://colab.research.google.com/gist/sayakpaul/3672419a04f5997827503fd84079bdd1/scratchpad.ipynb
210
+ if "s16" in checkpoint_url:
211
+ expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]])
212
+ elif "b16" in checkpoint_url:
213
+ expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]])
214
+ elif "l16" in checkpoint_url:
215
+ expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]])
216
+ elif "b4" in checkpoint_url:
217
+ expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]])
218
+ else:
219
+ expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]])
220
+
221
+ # verify logits
222
+ assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4)
223
+
224
+ print(f"Saving model to {pytorch_dump_folder_path}")
225
+ model.save_pretrained(pytorch_dump_folder_path)
226
+
227
+ print(f"Saving image processor to {pytorch_dump_folder_path}")
228
+ image_processor.save_pretrained(pytorch_dump_folder_path)
229
+
230
+
231
+ if __name__ == "__main__":
232
+ parser = argparse.ArgumentParser()
233
+ # Required parameters
234
+ parser.add_argument(
235
+ "--checkpoint_url",
236
+ default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
237
+ type=str,
238
+ help="URL of the checkpoint you'd like to convert.",
239
+ )
240
+ parser.add_argument(
241
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
242
+ )
243
+
244
+ args = parser.parse_args()
245
+ convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path)
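
For reference, the script can also be driven from Python by calling the conversion function directly (the output directory here is an arbitrary example path):

```python
from convert_msn_to_pytorch import convert_vit_msn_checkpoint

convert_vit_msn_checkpoint(
    "https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar",
    "./vit-msn-small",  # example output path
)
```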
docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py ADDED
@@ -0,0 +1,741 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViT MSN (masked siamese network) model."""
16
+
17
+ import collections.abc
18
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
27
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
28
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
29
+ from ...utils import (
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ torch_int,
35
+ )
36
+ from .configuration_vit_msn import ViTMSNConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ _CONFIG_FOR_DOC = "ViTMSNConfig"
43
+ _CHECKPOINT_FOR_DOC = "facebook/vit-msn-small"
44
+
45
+
46
+ class ViTMSNEmbeddings(nn.Module):
47
+ """
48
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
49
+ """
50
+
51
+ def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
52
+ super().__init__()
53
+
54
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
55
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
56
+ self.patch_embeddings = ViTMSNPatchEmbeddings(config)
57
+ num_patches = self.patch_embeddings.num_patches
58
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
59
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
60
+ self.patch_size = config.patch_size
61
+ self.config = config
62
+
63
+ # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
64
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
65
+ """
66
+ This method allows interpolating the pre-trained position encodings so the model can be used on higher resolution
67
+ images. It is also adapted to support torch.jit tracing.
68
+
69
+ Adapted from:
70
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
71
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
72
+ """
73
+
74
+ num_patches = embeddings.shape[1] - 1
75
+ num_positions = self.position_embeddings.shape[1] - 1
76
+
77
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
78
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
79
+ return self.position_embeddings
80
+
81
+ class_pos_embed = self.position_embeddings[:, :1]
82
+ patch_pos_embed = self.position_embeddings[:, 1:]
83
+
84
+ dim = embeddings.shape[-1]
85
+
86
+ new_height = height // self.patch_size
87
+ new_width = width // self.patch_size
88
+
89
+ sqrt_num_positions = torch_int(num_positions**0.5)
90
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
91
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
92
+
93
+ patch_pos_embed = nn.functional.interpolate(
94
+ patch_pos_embed,
95
+ size=(new_height, new_width),
96
+ mode="bicubic",
97
+ align_corners=False,
98
+ )
99
+
100
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
101
+
102
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
103
+
104
+ def forward(
105
+ self,
106
+ pixel_values: torch.Tensor,
107
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
108
+ interpolate_pos_encoding: bool = False,
109
+ ) -> torch.Tensor:
110
+ batch_size, num_channels, height, width = pixel_values.shape
111
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
112
+
113
+ if bool_masked_pos is not None:
114
+ seq_length = embeddings.shape[1]
115
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
116
+ # replace the masked visual tokens by mask_tokens
117
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
118
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
119
+
120
+ # add the [CLS] token to the embedded patch tokens
121
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
122
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
123
+
124
+ # add positional encoding to each token
125
+ if interpolate_pos_encoding:
126
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
127
+ else:
128
+ embeddings = embeddings + self.position_embeddings
129
+
130
+ embeddings = self.dropout(embeddings)
131
+
132
+ return embeddings
133
+
134
+
135
+ # Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN
136
+ class ViTMSNPatchEmbeddings(nn.Module):
137
+ """
138
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
139
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
140
+ Transformer.
141
+ """
142
+
143
+ def __init__(self, config):
144
+ super().__init__()
145
+ image_size, patch_size = config.image_size, config.patch_size
146
+ num_channels, hidden_size = config.num_channels, config.hidden_size
147
+
148
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
149
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
150
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
151
+ self.image_size = image_size
152
+ self.patch_size = patch_size
153
+ self.num_channels = num_channels
154
+ self.num_patches = num_patches
155
+
156
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
157
+
158
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
159
+ batch_size, num_channels, height, width = pixel_values.shape
160
+ if num_channels != self.num_channels:
161
+ raise ValueError(
162
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
163
+ f" Expected {self.num_channels} but got {num_channels}."
164
+ )
165
+ if not interpolate_pos_encoding:
166
+ if height != self.image_size[0] or width != self.image_size[1]:
167
+ raise ValueError(
168
+ f"Input image size ({height}*{width}) doesn't match model"
169
+ f" ({self.image_size[0]}*{self.image_size[1]})."
170
+ )
171
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
172
+ return embeddings
173
+
174
+
175
+ # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
176
+ def eager_attention_forward(
177
+ module: nn.Module,
178
+ query: torch.Tensor,
179
+ key: torch.Tensor,
180
+ value: torch.Tensor,
181
+ attention_mask: Optional[torch.Tensor],
182
+ scaling: float,
183
+ dropout: float = 0.0,
184
+ **kwargs,
185
+ ):
186
+ # Take the dot product between "query" and "key" to get the raw attention scores.
187
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
188
+
189
+ # Normalize the attention scores to probabilities.
190
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
191
+
192
+ # This is actually dropping out entire tokens to attend to, which might
193
+ # seem a bit unusual, but is taken from the original Transformer paper.
194
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
195
+
196
+ # Mask heads if we want to
197
+ if attention_mask is not None:
198
+ attn_weights = attn_weights * attention_mask
199
+
200
+ attn_output = torch.matmul(attn_weights, value)
201
+ attn_output = attn_output.transpose(1, 2).contiguous()
202
+
203
+ return attn_output, attn_weights
204
+
205
+
206
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN
207
+ class ViTMSNSelfAttention(nn.Module):
208
+ def __init__(self, config: ViTMSNConfig) -> None:
209
+ super().__init__()
210
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
211
+ raise ValueError(
212
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
213
+ f"heads {config.num_attention_heads}."
214
+ )
215
+
216
+ self.config = config
217
+ self.num_attention_heads = config.num_attention_heads
218
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
219
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
220
+ self.dropout_prob = config.attention_probs_dropout_prob
221
+ self.scaling = self.attention_head_size**-0.5
222
+ self.is_causal = False
223
+
224
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
225
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
226
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
227
+
228
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
229
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
230
+ x = x.view(new_x_shape)
231
+ return x.permute(0, 2, 1, 3)
232
+
233
+ def forward(
234
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
235
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
236
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
237
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
238
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
239
+
240
+ attention_interface: Callable = eager_attention_forward
241
+ if self.config._attn_implementation != "eager":
242
+ if self.config._attn_implementation == "sdpa" and output_attentions:
243
+ logger.warning_once(
244
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
245
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
246
+ )
247
+ else:
248
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
249
+
250
+ context_layer, attention_probs = attention_interface(
251
+ self,
252
+ query_layer,
253
+ key_layer,
254
+ value_layer,
255
+ head_mask,
256
+ is_causal=self.is_causal,
257
+ scaling=self.scaling,
258
+ dropout=0.0 if not self.training else self.dropout_prob,
259
+ )
260
+
261
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
262
+ context_layer = context_layer.reshape(new_context_layer_shape)
263
+
264
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
265
+
266
+ return outputs
267
+
268
+
269
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN
270
+ class ViTMSNSelfOutput(nn.Module):
271
+ """
272
+ The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the
273
+ layernorm applied before each block.
274
+ """
275
+
276
+ def __init__(self, config: ViTMSNConfig) -> None:
277
+ super().__init__()
278
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
279
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
280
+
281
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
282
+ hidden_states = self.dense(hidden_states)
283
+ hidden_states = self.dropout(hidden_states)
284
+
285
+ return hidden_states
286
+
287
+
288
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN
289
+ class ViTMSNAttention(nn.Module):
290
+ def __init__(self, config: ViTMSNConfig) -> None:
291
+ super().__init__()
292
+ self.attention = ViTMSNSelfAttention(config)
293
+ self.output = ViTMSNSelfOutput(config)
294
+ self.pruned_heads = set()
295
+
296
+ def prune_heads(self, heads: Set[int]) -> None:
297
+ if len(heads) == 0:
298
+ return
299
+ heads, index = find_pruneable_heads_and_indices(
300
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
301
+ )
302
+
303
+ # Prune linear layers
304
+ self.attention.query = prune_linear_layer(self.attention.query, index)
305
+ self.attention.key = prune_linear_layer(self.attention.key, index)
306
+ self.attention.value = prune_linear_layer(self.attention.value, index)
307
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
308
+
309
+ # Update hyper params and store pruned heads
310
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
311
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
312
+ self.pruned_heads = self.pruned_heads.union(heads)
313
+
314
+ def forward(
315
+ self,
316
+ hidden_states: torch.Tensor,
317
+ head_mask: Optional[torch.Tensor] = None,
318
+ output_attentions: bool = False,
319
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
320
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
321
+
322
+ attention_output = self.output(self_outputs[0], hidden_states)
323
+
324
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
325
+ return outputs
326
+
327
+
328
+ # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN
329
+ class ViTMSNIntermediate(nn.Module):
330
+ def __init__(self, config: ViTMSNConfig) -> None:
331
+ super().__init__()
332
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
333
+ if isinstance(config.hidden_act, str):
334
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
335
+ else:
336
+ self.intermediate_act_fn = config.hidden_act
337
+
338
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
339
+ hidden_states = self.dense(hidden_states)
340
+ hidden_states = self.intermediate_act_fn(hidden_states)
341
+
342
+ return hidden_states
343
+
344
+
345
+ # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN
+ class ViTMSNOutput(nn.Module):
+     def __init__(self, config: ViTMSNConfig) -> None:
+         super().__init__()
+         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+     def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+
+         hidden_states = hidden_states + input_tensor
+
+         return hidden_states
+
+
+ # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN
+ class ViTMSNLayer(nn.Module):
+     """This corresponds to the Block class in the timm implementation."""
+
+     def __init__(self, config: ViTMSNConfig) -> None:
+         super().__init__()
+         self.chunk_size_feed_forward = config.chunk_size_feed_forward
+         self.seq_len_dim = 1
+         self.attention = ViTMSNAttention(config)
+         self.intermediate = ViTMSNIntermediate(config)
+         self.output = ViTMSNOutput(config)
+         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+         self_attention_outputs = self.attention(
+             self.layernorm_before(hidden_states),  # in ViTMSN, layernorm is applied before self-attention
+             head_mask,
+             output_attentions=output_attentions,
+         )
+         attention_output = self_attention_outputs[0]
+         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+         # first residual connection
+         hidden_states = attention_output + hidden_states
+
+         # in ViTMSN, layernorm is also applied after self-attention
+         layer_output = self.layernorm_after(hidden_states)
+         layer_output = self.intermediate(layer_output)
+
+         # second residual connection is done here
+         layer_output = self.output(layer_output, hidden_states)
+
+         outputs = (layer_output,) + outputs
+
+         return outputs
+
+
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN
+ class ViTMSNEncoder(nn.Module):
+     def __init__(self, config: ViTMSNConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         output_hidden_states: bool = False,
+         return_dict: bool = True,
+     ) -> Union[tuple, BaseModelOutput]:
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attentions = () if output_attentions else None
+
+         for i, layer_module in enumerate(self.layer):
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             layer_head_mask = head_mask[i] if head_mask is not None else None
+
+             if self.gradient_checkpointing and self.training:
+                 layer_outputs = self._gradient_checkpointing_func(
+                     layer_module.__call__,
+                     hidden_states,
+                     layer_head_mask,
+                     output_attentions,
+                 )
+             else:
+                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+             hidden_states = layer_outputs[0]
+
+             if output_attentions:
+                 all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attentions,
+         )
+
+
+ class ViTMSNPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = ViTMSNConfig
+     base_model_prefix = "vit"
+     main_input_name = "pixel_values"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"]
+     _supports_sdpa = True
+     _supports_flash_attn_2 = True
+
+     # todo: Refer to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
+     # when creating pre-training scripts.
+     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+         """Initialize the weights"""
+         if isinstance(module, (nn.Linear, nn.Conv2d)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, ViTMSNEmbeddings):
+             module.cls_token.data.zero_()
+             module.position_embeddings.data.zero_()
+             if module.mask_token is not None:
+                 module.mask_token.data.zero_()
+
+
+ VIT_MSN_START_DOCSTRING = r"""
+     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+     as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+     behavior.
+
+     Parameters:
+         config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ VIT_MSN_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+             for details.
+
+         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+             - 1 indicates the head is **not masked**,
+             - 0 indicates the head is **masked**.
+
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+         interpolate_pos_encoding (`bool`, *optional*):
+             Whether to interpolate the pre-trained position encodings.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.",
+     VIT_MSN_START_DOCSTRING,
+ )
+ class ViTMSNModel(ViTMSNPreTrainedModel):
+     def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False):
+         super().__init__(config)
+         self.config = config
+
+         self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token)
+         self.encoder = ViTMSNEncoder(config)
+
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> ViTMSNPatchEmbeddings:
+         return self.embeddings.patch_embeddings
+
+     def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+         """
+         Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+         class PreTrainedModel
+         """
+         for layer, heads in heads_to_prune.items():
+             self.encoder.layer[layer].attention.prune_heads(heads)
+
+     @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         bool_masked_pos: Optional[torch.BoolTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         interpolate_pos_encoding: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, BaseModelOutput]:
+         r"""
+         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoImageProcessor, ViTMSNModel
+         >>> import torch
+         >>> from PIL import Image
+         >>> import requests
+
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+
+         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
+         >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
+         >>> inputs = image_processor(images=image, return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     outputs = model(**inputs)
+         >>> last_hidden_states = outputs.last_hidden_state
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if pixel_values is None:
+             raise ValueError("You have to specify pixel_values")
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicate we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         embedding_output = self.embeddings(
+             pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+         )
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+         sequence_output = self.layernorm(sequence_output)
+
+         if not return_dict:
+             head_outputs = (sequence_output,)
+             return head_outputs + encoder_outputs[1:]
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ # Caution: We don't have the weights for the classification head yet. This class
+ # is here for users interested in fine-tuning the base model (ViTMSNModel).
+ @add_start_docstrings(
+     """
+     ViTMSN Model with an image classification head on top e.g. for ImageNet.
+     """,
+     VIT_MSN_START_DOCSTRING,
+ )
+ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
+     def __init__(self, config: ViTMSNConfig) -> None:
+         super().__init__(config)
+
+         self.num_labels = config.num_labels
+         self.vit = ViTMSNModel(config)
+
+         # Classifier head
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         interpolate_pos_encoding: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[tuple, ImageClassifierOutput]:
+         r"""
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoImageProcessor, ViTMSNForImageClassification
+         >>> import torch
+         >>> from PIL import Image
+         >>> import requests
+
+         >>> torch.manual_seed(2)  # doctest: +IGNORE_RESULT
+
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> image = Image.open(requests.get(url, stream=True).raw)
+
+         >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small")
+         >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small")
+
+         >>> inputs = image_processor(images=image, return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     logits = model(**inputs).logits
+         >>> # model predicts one of the 1000 ImageNet classes
+         >>> predicted_label = logits.argmax(-1).item()
+         >>> print(model.config.id2label[predicted_label])
+         tusker
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.vit(
+             pixel_values,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+
+         logits = self.classifier(sequence_output[:, 0, :])
+
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return ImageClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ __all__ = ["ViTMSNModel", "ViTMSNForImageClassification", "ViTMSNPreTrainedModel"]
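The `problem_type` dispatch in `ViTMSNForImageClassification.forward` above is easy to trip over, so here is a minimal sketch of how the label tensor's dtype selects the loss. It uses a randomly initialized model with a hypothetical `num_labels=3`, since (as the caution note above says) the released checkpoints ship no classifier weights:

```python
import torch
from transformers import ViTMSNConfig, ViTMSNForImageClassification

# Random-weight model; num_labels=3 is a hypothetical choice for illustration.
config = ViTMSNConfig(num_labels=3)
model = ViTMSNForImageClassification(config)
pixel_values = torch.randn(2, 3, 224, 224)

# Integer class ids -> problem_type becomes "single_label_classification" (CrossEntropyLoss).
out = model(pixel_values, labels=torch.tensor([0, 2]))
print(model.config.problem_type, out.loss.item())

# Float multi-hot targets -> "multi_label_classification" (BCEWithLogitsLoss).
model.config.problem_type = None  # reset the cached dispatch before switching label style
out = model(pixel_values, labels=torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]))
print(model.config.problem_type, out.loss.item())
```

Note that the chosen `problem_type` is cached on the config after the first labeled forward pass, which is why the sketch resets it before switching label formats.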
docs/transformers/build/lib/transformers/models/vitdet/__init__.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_vitdet import *
+     from .modeling_vitdet import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
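The `__init__.py` files in this commit all use the same `_LazyModule` pattern: the package only imports its heavy submodules when an attribute is first accessed. A stripped-down sketch of the idea (standalone, not the private transformers helpers, whose exact behavior is internal):

```python
import importlib
import types

# Minimal illustration of the _LazyModule idea: attributes resolve to submodule
# imports on first access and are then cached on the module object.
class LazyModule(types.ModuleType):
    def __init__(self, name, attr_to_submodule):
        super().__init__(name)
        self._attr_to_submodule = attr_to_submodule

    def __getattr__(self, attr):
        # A real implementation would raise AttributeError on unknown names.
        submodule = self._attr_to_submodule[attr]
        value = getattr(importlib.import_module(submodule), attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value

# Hypothetical mapping, mirroring how configuration/modeling files are wired up.
lazy = LazyModule("demo", {"dataclass": "dataclasses"})
print(lazy.dataclass)  # the dataclasses module is only imported at this point
```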
docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py ADDED
@@ -0,0 +1,156 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """VitDet model configuration"""
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VitDetConfig(BackboneConfigMixin, PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`VitDetModel`]. It is used to instantiate a
+     VitDet model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the VitDet
+     [google/vitdet-base-patch16-224](https://huggingface.co/google/vitdet-base-patch16-224) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the encoder layers and the pooler layer.
+         num_hidden_layers (`int`, *optional*, defaults to 12):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 12):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         mlp_ratio (`int`, *optional*, defaults to 4):
+             Ratio of mlp hidden dim to embedding dim.
+         hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+             `"relu"`, `"selu"` and `"gelu_new"` are supported.
+         dropout_prob (`float`, *optional*, defaults to 0.0):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the layer normalization layers.
+         image_size (`int`, *optional*, defaults to 224):
+             The size (resolution) of each image.
+         pretrain_image_size (`int`, *optional*, defaults to 224):
+             The size (resolution) of each image during pretraining.
+         patch_size (`int`, *optional*, defaults to 16):
+             The size (resolution) of each patch.
+         num_channels (`int`, *optional*, defaults to 3):
+             The number of input channels.
+         qkv_bias (`bool`, *optional*, defaults to `True`):
+             Whether to add a bias to the queries, keys and values.
+         drop_path_rate (`float`, *optional*, defaults to 0.0):
+             Stochastic depth rate.
+         window_block_indices (`List[int]`, *optional*, defaults to `[]`):
+             List of indices of blocks that should have window attention instead of regular global self-attention.
+         residual_block_indices (`List[int]`, *optional*, defaults to `[]`):
+             List of indices of blocks that should have an extra residual block after the MLP.
+         use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
+             Whether to add absolute position embeddings to the patch embeddings.
+         use_relative_position_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether to add relative position embeddings to the attention maps.
+         window_size (`int`, *optional*, defaults to 0):
+             The size of the attention window.
+         out_features (`List[str]`, *optional*):
+             If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+             (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+             corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+             same order as defined in the `stage_names` attribute.
+         out_indices (`List[int]`, *optional*):
+             If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+             many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+             If unset and `out_features` is unset, will default to the last stage. Must be in the
+             same order as defined in the `stage_names` attribute.
+
+     Example:
+
+     ```python
+     >>> from transformers import VitDetConfig, VitDetModel
+
+     >>> # Initializing a VitDet configuration
+     >>> configuration = VitDetConfig()
+
+     >>> # Initializing a model (with random weights) from the configuration
+     >>> model = VitDetModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "vitdet"
+
+     def __init__(
+         self,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         mlp_ratio=4,
+         hidden_act="gelu",
+         dropout_prob=0.0,
+         initializer_range=0.02,
+         layer_norm_eps=1e-6,
+         image_size=224,
+         pretrain_image_size=224,
+         patch_size=16,
+         num_channels=3,
+         qkv_bias=True,
+         drop_path_rate=0.0,
+         window_block_indices=[],
+         residual_block_indices=[],
+         use_absolute_position_embeddings=True,
+         use_relative_position_embeddings=False,
+         window_size=0,
+         out_features=None,
+         out_indices=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.hidden_size = hidden_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.mlp_ratio = mlp_ratio
+         self.hidden_act = hidden_act
+         self.dropout_prob = dropout_prob
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.image_size = image_size
+         self.pretrain_image_size = pretrain_image_size
+         self.patch_size = patch_size
+         self.num_channels = num_channels
+         self.qkv_bias = qkv_bias
+         self.drop_path_rate = drop_path_rate
+         self.window_block_indices = window_block_indices
+         self.residual_block_indices = residual_block_indices
+         self.use_absolute_position_embeddings = use_absolute_position_embeddings
+         self.use_relative_position_embeddings = use_relative_position_embeddings
+         self.window_size = window_size
+
+         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)]
+         self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+             out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+         )
+
+
+ __all__ = ["VitDetConfig"]
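As a quick illustration of the windowed-attention knobs documented above: the sketch below puts most blocks on 14x14 window attention and leaves a few blocks global. The particular index split is an assumption for illustration, not the configuration of any released checkpoint:

```python
from transformers import VitDetConfig

# Blocks listed in window_block_indices use window attention of size `window_size`;
# blocks 2, 5, 8 and 11 keep global self-attention (an illustrative split).
config = VitDetConfig(
    window_size=14,
    window_block_indices=[0, 1, 3, 4, 6, 7, 9, 10],
    use_relative_position_embeddings=True,
)
print(config.window_size, config.window_block_indices)
print(config.out_features)  # defaults to the last stage when out_features/out_indices are unset
```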
docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py ADDED
@@ -0,0 +1,883 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViTDet backbone."""
16
+
17
+ import collections.abc
18
+ import math
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+
25
+ from ...activations import ACT2FN
26
+ from ...modeling_outputs import BackboneOutput, BaseModelOutput
27
+ from ...modeling_utils import PreTrainedModel
28
+ from ...utils import (
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ logging,
32
+ replace_return_docstrings,
33
+ )
34
+ from ...utils.backbone_utils import BackboneMixin
35
+ from .configuration_vitdet import VitDetConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ # General docstring
41
+ _CONFIG_FOR_DOC = "VitDetConfig"
42
+
43
+
44
+ class VitDetEmbeddings(nn.Module):
45
+ """
46
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
47
+ `hidden_states` (patch embeddings) to be consumed by a Transformer.
48
+ """
49
+
50
+ def __init__(self, config):
51
+ super().__init__()
52
+ image_size, patch_size = config.pretrain_image_size, config.patch_size
53
+ num_channels, hidden_size = config.num_channels, config.hidden_size
54
+
55
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
56
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
57
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
58
+ self.image_size = image_size
59
+ self.patch_size = patch_size
60
+ self.num_channels = num_channels
61
+ self.num_patches = num_patches
62
+
63
+ if config.use_absolute_position_embeddings:
64
+ # Initialize absolute positional embedding with pretrain image size.
65
+ num_positions = num_patches + 1
66
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size))
67
+ else:
68
+ self.position_embeddings = None
69
+
70
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
71
+
72
+ def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width):
73
+ """
74
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the
75
+ original embeddings.
76
+
77
+ Args:
78
+ abs_pos_embeddings (`torch.Tensor`):
79
+ Absolute positional embeddings with (1, num_position, num_channels).
80
+ has_cls_token (`bool`):
81
+ If true, has 1 embedding in abs_pos_embeddings for cls token.
82
+ height (`int`):
83
+ Height of input image tokens.
84
+ width (`int`):
85
+ Width of input image tokens.
86
+
87
+ Returns:
88
+ Absolute positional embeddings after processing with shape (1, height, width, num_channels)
89
+ """
90
+ if has_cls_token:
91
+ abs_pos_embeddings = abs_pos_embeddings[:, 1:]
92
+ num_position = abs_pos_embeddings.shape[1]
93
+ size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
94
+ if size * size != num_position:
95
+ raise ValueError("Absolute position embeddings must be a square number.")
96
+
97
+ if torch.jit.is_tracing() or (size != height or size != width):
98
+ # nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
99
+ new_abs_pos_embeddings = nn.functional.interpolate(
100
+ abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
101
+ size=(height, width),
102
+ mode="bicubic",
103
+ align_corners=False,
104
+ )
105
+
106
+ return new_abs_pos_embeddings.permute(0, 2, 3, 1)
107
+ else:
108
+ return abs_pos_embeddings.reshape(1, height, width, -1)
109
+
110
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
111
+ num_channels = pixel_values.shape[1]
112
+ if num_channels != self.num_channels:
113
+ raise ValueError(
114
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
115
+ f" Expected {self.num_channels} but got {num_channels}."
116
+ )
117
+ embeddings = self.projection(pixel_values)
118
+
119
+ if self.position_embeddings is not None:
120
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
121
+ embeddings = embeddings.permute(0, 2, 3, 1)
122
+ # add position embeddings
123
+ embeddings = embeddings + self.get_absolute_positions(
124
+ self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2]
125
+ )
126
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
127
+ embeddings = embeddings.permute(0, 3, 1, 2)
128
+
129
+ return embeddings
130
+
131
+
132
+ @torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
133
+ def get_rel_pos(q_size, k_size, rel_pos):
134
+ """
135
+ Get relative positional embeddings according to the relative positions of query and key sizes.
136
+
137
+ Args:
138
+ q_size (`int`):
139
+ Size of query q.
140
+ k_size (`int`):
141
+ Size of key k.
142
+ rel_pos (`torch.Tensor`):
143
+ Relative position embeddings (num_embeddings, num_channels).
144
+
145
+ Returns:
146
+ Extracted positional embeddings according to relative positions.
147
+ """
148
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
149
+ # Interpolate rel pos if needed.
150
+ if rel_pos.shape[0] != max_rel_dist:
151
+ # Interpolate rel position embeddings.
152
+ rel_pos_resized = nn.functional.interpolate(
153
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
154
+ size=max_rel_dist,
155
+ mode="linear",
156
+ )
157
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
158
+ else:
159
+ rel_pos_resized = rel_pos
160
+
161
+ # Scale the coords with short length if shapes for q and k are different.
162
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
163
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
164
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
165
+
166
+ return rel_pos_resized[relative_coords.long()]
167
+
168
+
169
+ def add_decomposed_relative_positions(attn, queries, rel_pos_h, rel_pos_w, q_size, k_size):
170
+ """
171
+ Calculate decomposed Relative Positional Embeddings as introduced in
172
+ [MViT2](https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py).
173
+
174
+ Args:
175
+ attn (`torch.Tensor`):
176
+ Attention map.
177
+ queries (`torch.Tensor`):
178
+ Query q in the attention layer with shape (batch_size, queries_height * queries_width, num_channels).
179
+ rel_pos_h (`torch.Tensor`):
180
+ Relative position embeddings (Lh, num_channels) for height axis.
181
+ rel_pos_w (`torch.Tensor`):
182
+ Relative position embeddings (Lw, num_channels) for width axis.
183
+ q_size (`Tuple[int]`):
184
+ Spatial sequence size of query q with (queries_height, queries_width).
185
+ k_size (`Tuple[int]`):
186
+ Spatial sequence size of key k with (keys_height, keys_width).
187
+
188
+ Returns:
189
+ attn (Tensor): attention map with added relative positional embeddings.
190
+ """
191
+ queries_height, queries_width = q_size
192
+ keys_height, keys_width = k_size
193
+ relative_height = get_rel_pos(queries_height, keys_height, rel_pos_h)
194
+ relative_width = get_rel_pos(queries_width, keys_width, rel_pos_w)
195
+
196
+ batch_size, _, dim = queries.shape
197
+ r_q = queries.reshape(batch_size, queries_height, queries_width, dim)
198
+ relative_height = torch.einsum("bhwc,hkc->bhwk", r_q, relative_height)
199
+ relative_weight = torch.einsum("bhwc,wkc->bhwk", r_q, relative_width)
200
+
201
+ attn = (
202
+ attn.view(batch_size, queries_height, queries_width, keys_height, keys_width)
203
+ + relative_height[:, :, :, :, None]
204
+ + relative_weight[:, :, :, None, :]
205
+ ).view(batch_size, queries_height * queries_width, keys_height * keys_width)
206
+
207
+ return attn
208
+
209
+
210
+ class VitDetAttention(nn.Module):
211
+ """Multi-head Attention block with relative position embeddings."""
212
+
213
+ def __init__(self, config, input_size=None):
214
+ """
215
+ Args:
216
+ config (`VitDetConfig`):
217
+ Model configuration.
218
+ input_size (`Tuple[int]`, *optional*):
219
+ Input resolution, only required in case relative position embeddings are added.
220
+ """
221
+ super().__init__()
222
+
223
+ dim = config.hidden_size
224
+ num_heads = config.num_attention_heads
225
+
226
+ self.num_heads = num_heads
227
+ head_dim = dim // num_heads
228
+ self.scale = head_dim**-0.5
229
+
230
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
231
+ self.proj = nn.Linear(dim, dim)
232
+
233
+ self.use_relative_position_embeddings = config.use_relative_position_embeddings
234
+ if self.use_relative_position_embeddings:
235
+ # initialize relative positional embeddings
236
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
237
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
238
+
239
+ def forward(self, hidden_state, output_attentions=False):
240
+ batch_size, height, width, _ = hidden_state.shape
241
+ # qkv with shape (3, batch_size, num_heads, height * width, num_channels)
242
+ qkv = self.qkv(hidden_state).reshape(batch_size, height * width, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
243
+ # queries, keys and values have shape (batch_size * num_heads, height * width, num_channels)
244
+ queries, keys, values = qkv.reshape(3, batch_size * self.num_heads, height * width, -1).unbind(0)
245
+
246
+ attention_scores = (queries * self.scale) @ keys.transpose(-2, -1)
247
+
248
+ if self.use_relative_position_embeddings:
249
+ attention_scores = add_decomposed_relative_positions(
250
+ attention_scores, queries, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
251
+ )
252
+
253
+ attention_probs = attention_scores.softmax(dim=-1)
254
+
255
+ hidden_state = attention_probs @ values
256
+ hidden_state = hidden_state.view(batch_size, self.num_heads, height, width, -1)
257
+ hidden_state = hidden_state.permute(0, 2, 3, 1, 4)
258
+ hidden_state = hidden_state.reshape(batch_size, height, width, -1)
259
+ hidden_state = self.proj(hidden_state)
260
+
261
+ if output_attentions:
262
+ attention_probs = attention_probs.reshape(
263
+ batch_size, self.num_heads, attention_probs.shape[-2], attention_probs.shape[-1]
264
+ )
265
+ outputs = (hidden_state, attention_probs)
266
+ else:
267
+ outputs = (hidden_state,)
268
+
269
+ return outputs
270
+
271
+
272
+ # Copied from transformers.models.beit.modeling_beit.drop_path
273
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
274
+ """
275
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
276
+
277
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
278
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
279
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
280
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
281
+ argument.
282
+ """
283
+ if drop_prob == 0.0 or not training:
284
+ return input
285
+ keep_prob = 1 - drop_prob
286
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
287
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
288
+ random_tensor.floor_() # binarize
289
+ output = input.div(keep_prob) * random_tensor
290
+ return output
291
+
292
+
293
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
294
+ class VitDetDropPath(nn.Module):
295
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
296
+
297
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
298
+ super().__init__()
299
+ self.drop_prob = drop_prob
300
+
301
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
302
+ return drop_path(hidden_states, self.drop_prob, self.training)
303
+
304
+ def extra_repr(self) -> str:
305
+ return "p={}".format(self.drop_prob)
306
+
307
+
308
+ class VitDetLayerNorm(nn.Module):
309
+ """
310
+ A LayerNorm variant, popularized by Transformers, that performs point-wise mean and variance normalization over the
311
+ channel dimension for inputs that have shape (batch_size, channels, height, width).
312
+ https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
313
+ """
314
+
315
+ def __init__(self, normalized_shape, eps=1e-6):
316
+ super().__init__()
317
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
318
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
319
+ self.eps = eps
320
+ self.normalized_shape = (normalized_shape,)
321
+
322
+ def forward(self, x):
323
+ u = x.mean(1, keepdim=True)
324
+ s = (x - u).pow(2).mean(1, keepdim=True)
325
+ x = (x - u) / torch.sqrt(s + self.eps)
326
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
327
+ return x
328
+
329
+
330
+ class VitDetResBottleneckBlock(nn.Module):
331
+ """
332
+ The standard bottleneck residual block without the last activation layer. It contains 3 conv layers with kernels
333
+ 1x1, 3x3, 1x1.
334
+ """
335
+
336
+ def __init__(self, config, in_channels, out_channels, bottleneck_channels):
337
+ """
338
+ Args:
339
+ config (`VitDetConfig`):
340
+ Model configuration.
341
+ in_channels (`int`):
342
+ Number of input channels.
343
+ out_channels (`int`):
344
+ Number of output channels.
345
+ bottleneck_channels (`int`):
346
+ Number of output channels for the 3x3 "bottleneck" conv layers.
347
+ """
348
+ super().__init__()
349
+ self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False)
350
+ self.norm1 = VitDetLayerNorm(bottleneck_channels)
351
+ self.act1 = ACT2FN[config.hidden_act]
352
+
353
+ self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False)
354
+ self.norm2 = VitDetLayerNorm(bottleneck_channels)
355
+ self.act2 = ACT2FN[config.hidden_act]
356
+
357
+ self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False)
358
+ self.norm3 = VitDetLayerNorm(out_channels)
359
+
360
+ def forward(self, x):
361
+ out = x
362
+ for layer in self.children():
363
+ out = layer(out)
364
+
365
+ out = x + out
366
+ return out
367
+
368
+
369
+ class VitDetMlp(nn.Module):
370
+ def __init__(self, config, in_features: int, hidden_features: int) -> None:
371
+ super().__init__()
372
+ self.fc1 = nn.Linear(in_features, hidden_features)
373
+ self.act = ACT2FN[config.hidden_act]
374
+ self.fc2 = nn.Linear(hidden_features, in_features)
375
+ self.drop = nn.Dropout(config.dropout_prob)
376
+
377
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
378
+ x = self.fc1(x)
379
+ x = self.act(x)
380
+ x = self.drop(x)
381
+ x = self.fc2(x)
382
+ x = self.drop(x)
383
+
384
+ return x
385
+
386
+
387
+ def window_partition(hidden_state, window_size):
388
+ """
389
+ Partition into non-overlapping windows with padding if needed.
390
+
391
+ Args:
392
+ hidden_state (`torch.Tensor`):
393
+ Input tokens with [batch_size, height, width, num_channels].
394
+ window_size (`int`):
395
+ Window size.
396
+
397
+ Returns:
398
+ `tuple(torch.FloatTensor)` comprising various elements:
399
+ - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
400
+ - (padded_height, padded_width): padded height and width before partition
401
+ """
402
+ batch_size, height, width, num_channels = hidden_state.shape
403
+
404
+ pad_height = (window_size - height % window_size) % window_size
405
+ pad_width = (window_size - width % window_size) % window_size
406
+
407
+ # Noop in case pad_width == 0 and pad_height == 0.
408
+ hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
409
+
410
+ padded_height, padded_width = height + pad_height, width + pad_width
411
+
412
+ hidden_state = hidden_state.view(
413
+ batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
414
+ )
415
+ windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
416
+ return windows, (padded_height, padded_width)
417
+
418
+
419
+ def window_unpartition(windows, window_size, pad_height_width, height_width):
420
+ """
421
+ Window unpartition into original sequences and removing padding.
422
+
423
+ Args:
424
+ windows (`torch.Tensor`):
425
+ Input tokens with [batch_size * num_windows, window_size, window_size, num_channels].
426
+ window_size (`int`):
427
+ Window size.
428
+ pad_height_width (`Tuple[int]`):
429
+ Padded height and width (padded_height, padded_width).
430
+ height_width (`Tuple[int]`):
431
+ Original height and width before padding.
432
+
433
+ Returns:
434
+ hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
435
+ """
436
+ padded_height, padded_width = pad_height_width
437
+ height, width = height_width
438
+ batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
439
+ hidden_state = windows.view(
440
+ batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
441
+ )
442
+ hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
443
+ hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
444
+
445
+ # We always have height <= padded_height and width <= padded_width
446
+ hidden_state = hidden_state[:, :height, :width, :].contiguous()
447
+ return hidden_state
448
+
449
+
450
+ class VitDetLayer(nn.Module):
451
+ """This corresponds to the Block class in the original implementation."""
452
+
453
+ def __init__(
454
+ self, config: VitDetConfig, drop_path_rate: float = 0, window_size: int = 0, use_residual_block: bool = False
455
+ ) -> None:
456
+ super().__init__()
457
+
458
+ dim = config.hidden_size
459
+
460
+ image_size = config.image_size
461
+ image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size)
462
+
463
+ patch_size = config.patch_size
464
+ patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size)
465
+
466
+ input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
467
+ self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
468
+ self.attention = VitDetAttention(
469
+ config, input_size=input_size if window_size == 0 else (window_size, window_size)
470
+ )
471
+
472
+ self.drop_path = VitDetDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
473
+ self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
474
+ self.mlp = VitDetMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))
475
+
476
+ self.window_size = window_size
477
+
478
+ self.use_residual_block = use_residual_block
479
+ if self.use_residual_block:
480
+ # Use a residual block with bottleneck channel as dim // 2
481
+ self.residual = VitDetResBottleneckBlock(
482
+ config=config,
483
+ in_channels=dim,
484
+ out_channels=dim,
485
+ bottleneck_channels=dim // 2,
486
+ )
487
+
488
+ def forward(
489
+ self,
490
+ hidden_states: torch.Tensor,
491
+ head_mask: Optional[torch.Tensor] = None,
492
+ output_attentions: bool = False,
493
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
494
+ hidden_states = hidden_states.permute(0, 2, 3, 1)
495
+
496
+ shortcut = hidden_states
497
+
498
+ hidden_states = self.norm1(hidden_states)
499
+
500
+ # Window partition
501
+ if self.window_size > 0:
502
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
503
+ hidden_states, pad_height_width = window_partition(hidden_states, self.window_size)
504
+
505
+ self_attention_outputs = self.attention(
506
+ hidden_states,
507
+ output_attentions=output_attentions,
508
+ )
509
+ hidden_states = self_attention_outputs[0]
510
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
511
+
512
+ # Reverse window partition
513
+ if self.window_size > 0:
514
+ hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width))
515
+
516
+ # first residual connection
517
+ hidden_states = shortcut + self.drop_path(hidden_states)
518
+
519
+ hidden_states = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states)))
520
+
521
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
522
+
523
+ if self.use_residual_block:
524
+ hidden_states = self.residual(hidden_states)
525
+
526
+ outputs = (hidden_states,) + outputs
527
+
528
+ return outputs
529
+
530
+
531
+ class VitDetEncoder(nn.Module):
532
+ def __init__(self, config: VitDetConfig) -> None:
533
+ super().__init__()
534
+ self.config = config
535
+ depth = config.num_hidden_layers
536
+
537
+ # stochastic depth decay rule
538
+ drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")]
539
+
540
+ layers = []
541
+ for i in range(depth):
542
+ layers.append(
543
+ VitDetLayer(
544
+ config,
545
+ drop_path_rate=drop_path_rate[i],
546
+ window_size=config.window_size if i in config.window_block_indices else 0,
547
+ use_residual_block=i in config.residual_block_indices,
548
+ )
549
+ )
550
+
551
+ self.layer = nn.ModuleList(layers)
552
+ self.gradient_checkpointing = False
553
+
554
+ def forward(
555
+ self,
556
+ hidden_states: torch.Tensor,
557
+ head_mask: Optional[torch.Tensor] = None,
558
+ output_attentions: bool = False,
559
+ output_hidden_states: bool = False,
560
+ return_dict: bool = True,
561
+ ) -> Union[tuple, BaseModelOutput]:
562
+ all_hidden_states = () if output_hidden_states else None
563
+ all_self_attentions = () if output_attentions else None
564
+
565
+ for i, layer_module in enumerate(self.layer):
566
+ if output_hidden_states:
567
+ all_hidden_states = all_hidden_states + (hidden_states,)
568
+
569
+ layer_head_mask = head_mask[i] if head_mask is not None else None
570
+
571
+ if self.gradient_checkpointing and self.training:
572
+ layer_outputs = self._gradient_checkpointing_func(
573
+ layer_module.__call__,
574
+ hidden_states,
575
+ layer_head_mask,
576
+ output_attentions,
577
+ )
578
+ else:
579
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
580
+
581
+ hidden_states = layer_outputs[0]
582
+
583
+ if output_attentions:
584
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
585
+
586
+ if output_hidden_states:
587
+ all_hidden_states = all_hidden_states + (hidden_states,)
588
+
589
+ if not return_dict:
590
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
591
+ return BaseModelOutput(
592
+ last_hidden_state=hidden_states,
593
+ hidden_states=all_hidden_states,
594
+ attentions=all_self_attentions,
595
+ )
596
+
597
+
598
+ def caffe2_msra_fill(module: nn.Module) -> None:
599
+ """
600
+ Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. Also initializes `module.bias` to 0.
601
+
602
+ Source: https://detectron2.readthedocs.io/en/latest/_modules/fvcore/nn/weight_init.html.
603
+
604
+ Args:
605
+ module (torch.nn.Module): module to initialize.
606
+ """
607
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
608
+ if module.bias is not None:
609
+ nn.init.constant_(module.bias, 0)
610
+
611
+
612
+ class VitDetPreTrainedModel(PreTrainedModel):
613
+ """
614
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
615
+ models.
616
+ """
617
+
618
+ config_class = VitDetConfig
619
+ base_model_prefix = "vitdet"
620
+ main_input_name = "pixel_values"
621
+ supports_gradient_checkpointing = True
622
+ _no_split_modules = []
623
+
624
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
625
+ """Initialize the weights"""
626
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
627
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
628
+ # `trunc_normal_cpu` not implemented in `half` issues
629
+ module.weight.data = nn.init.trunc_normal_(
630
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
631
+ ).to(module.weight.dtype)
632
+ if module.bias is not None:
633
+ module.bias.data.zero_()
634
+ elif isinstance(module, nn.LayerNorm):
635
+ module.bias.data.zero_()
636
+ module.weight.data.fill_(1.0)
637
+
638
+ elif isinstance(module, VitDetEmbeddings):
639
+ module.position_embeddings.data = nn.init.trunc_normal_(
640
+ module.position_embeddings.data.to(torch.float32),
641
+ mean=0.0,
642
+ std=self.config.initializer_range,
643
+ ).to(module.position_embeddings.dtype)
644
+
645
+ elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings:
646
+ module.rel_pos_h.data = nn.init.trunc_normal_(
647
+ module.rel_pos_h.data.to(torch.float32),
648
+ mean=0.0,
649
+ std=self.config.initializer_range,
650
+ )
651
+ module.rel_pos_w.data = nn.init.trunc_normal_(
652
+ module.rel_pos_w.data.to(torch.float32),
653
+ mean=0.0,
654
+ std=self.config.initializer_range,
655
+ )
656
+
657
+ elif isinstance(module, VitDetResBottleneckBlock):
658
+ for layer in [module.conv1, module.conv2, module.conv3]:
659
+ caffe2_msra_fill(layer)
660
+ for layer in [module.norm1, module.norm2]:
661
+ layer.weight.data.fill_(1.0)
662
+ layer.bias.data.zero_()
663
+ # zero init last norm layer.
664
+ module.norm3.weight.data.zero_()
665
+ module.norm3.bias.data.zero_()
666
+
667
+
668
+ VITDET_START_DOCSTRING = r"""
669
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
670
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
671
+ behavior.
672
+
673
+ Parameters:
674
+ config ([`VitDetConfig`]): Model configuration class with all the parameters of the model.
675
+ Initializing with a config file does not load the weights associated with the model, only the
676
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
677
+ """
678
+
679
+ VITDET_INPUTS_DOCSTRING = r"""
680
+ Args:
681
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
682
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
683
+ for details.
684
+
685
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
686
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
687
+
688
+ - 1 indicates the head is **not masked**,
689
+ - 0 indicates the head is **masked**.
690
+
691
+ output_attentions (`bool`, *optional*):
692
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
693
+ tensors for more detail.
694
+ output_hidden_states (`bool`, *optional*):
695
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
696
+ more detail.
697
+ return_dict (`bool`, *optional*):
698
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
699
+ """
700
+
701
+
702
+ @add_start_docstrings(
703
+ "The bare VitDet Transformer model outputting raw hidden-states without any specific head on top.",
704
+ VITDET_START_DOCSTRING,
705
+ )
706
+ class VitDetModel(VitDetPreTrainedModel):
707
+ def __init__(self, config: VitDetConfig):
708
+ super().__init__(config)
709
+ self.config = config
710
+
711
+ self.embeddings = VitDetEmbeddings(config)
712
+ self.encoder = VitDetEncoder(config)
713
+
714
+ # Initialize weights and apply final processing
715
+ self.post_init()
716
+
717
+ def get_input_embeddings(self) -> VitDetEmbeddings:
718
+ return self.embeddings.projection
719
+
720
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
721
+ """
722
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
723
+ class PreTrainedModel
724
+ """
725
+ for layer, heads in heads_to_prune.items():
726
+ self.encoder.layer[layer].attention.prune_heads(heads)
727
+
728
+ @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING)
729
+ @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
730
+ def forward(
731
+ self,
732
+ pixel_values: Optional[torch.Tensor] = None,
733
+ head_mask: Optional[torch.Tensor] = None,
734
+ output_attentions: Optional[bool] = None,
735
+ output_hidden_states: Optional[bool] = None,
736
+ return_dict: Optional[bool] = None,
737
+ ) -> Union[Tuple, BaseModelOutput]:
738
+ """
739
+ Returns:
740
+
741
+ Examples:
742
+
743
+ ```python
744
+ >>> from transformers import VitDetConfig, VitDetModel
745
+ >>> import torch
746
+
747
+ >>> config = VitDetConfig()
748
+ >>> model = VitDetModel(config)
749
+
750
+         >>> pixel_values = torch.randn(1, 3, 224, 224)
+
+         >>> with torch.no_grad():
+         ...     outputs = model(pixel_values)
+
+         >>> last_hidden_states = outputs.last_hidden_state
+         >>> list(last_hidden_states.shape)
+         [1, 768, 14, 14]
+         ```"""
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if pixel_values is None:
+             raise ValueError("You have to specify pixel_values")
+
+         # Prepare head mask if needed
+         # 1.0 in head_mask indicates we keep the head
+         # attention_probs has shape bsz x n_heads x N x N
+         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+         embedding_output = self.embeddings(pixel_values)
+
+         encoder_outputs = self.encoder(
+             embedding_output,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_output = encoder_outputs[0]
+
+         if not return_dict:
+             return (sequence_output,) + encoder_outputs[1:]
+
+         return BaseModelOutput(
+             last_hidden_state=sequence_output,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ @add_start_docstrings(
+     """
+     ViTDet backbone, to be used with frameworks like Mask R-CNN.
+     """,
+     VITDET_START_DOCSTRING,
+ )
+ class VitDetBackbone(VitDetPreTrainedModel, BackboneMixin):
+     def __init__(self, config):
+         super().__init__(config)
+         super()._init_backbone(config)
+
+         self.embeddings = VitDetEmbeddings(config)
+         self.encoder = VitDetEncoder(config)
+         self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
+
+         # initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self) -> VitDetEmbeddings:
+         return self.embeddings.projection
+
+     @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING)
+     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> BackboneOutput:
+         """
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from transformers import VitDetConfig, VitDetBackbone
+         >>> import torch
+
+         >>> config = VitDetConfig()
+         >>> model = VitDetBackbone(config)
+
+         >>> pixel_values = torch.randn(1, 3, 224, 224)
+
+         >>> with torch.no_grad():
+         ...     outputs = model(pixel_values)
+
+         >>> feature_maps = outputs.feature_maps
+         >>> list(feature_maps[-1].shape)
+         [1, 768, 14, 14]
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+         embedding_output = self.embeddings(pixel_values)
+
+         outputs = self.encoder(
+             embedding_output,
+             output_hidden_states=True,
+             output_attentions=output_attentions,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs.hidden_states if return_dict else outputs[1]
+
+         feature_maps = ()
+         for stage, hidden_state in zip(self.stage_names, hidden_states):
+             if stage in self.out_features:
+                 feature_maps += (hidden_state,)
+
+         if not return_dict:
+             if output_hidden_states:
+                 output = (feature_maps,) + outputs[1:]
+             else:
+                 output = (feature_maps,) + outputs[2:]
+             return output
+
+         return BackboneOutput(
+             feature_maps=feature_maps,
+             hidden_states=outputs.hidden_states if output_hidden_states else None,
+             attentions=outputs.attentions,
+         )
+
+
+ __all__ = ["VitDetModel", "VitDetPreTrainedModel", "VitDetBackbone"]
docs/transformers/build/lib/transformers/models/vitmatte/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_vitmatte import *
+     from .image_processing_vitmatte import *
+     from .modeling_vitmatte import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
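The `else` branch above swaps the package module for a `_LazyModule`, so importing the package stays cheap and heavy submodules are only materialized on first attribute access. A small illustration (assuming a transformers install that ships this module):

```python
from transformers.models import vitmatte

# Accessing the attribute triggers the actual import of configuration_vitmatte.
config_cls = vitmatte.VitMatteConfig
print(config_cls.model_type)  # "vitmatte"
```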
docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py ADDED
@@ -0,0 +1,136 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """VitMatte model configuration"""
+
+ import copy
+ from typing import List
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+ from ...utils.backbone_utils import verify_backbone_config_arguments
+ from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VitMatteConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of [`VitMatteForImageMatting`]. It is used to
+     instantiate a ViTMatte model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the ViTMatte
+     [hustvl/vitmatte-small-composition-1k](https://huggingface.co/hustvl/vitmatte-small-composition-1k) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
+             The configuration of the backbone model.
+         backbone (`str`, *optional*):
+             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+             will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to use pretrained weights for the backbone.
+         use_timm_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+             library.
+         backbone_kwargs (`dict`, *optional*):
+             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+         hidden_size (`int`, *optional*, defaults to 384):
+             The number of input channels of the decoder.
+         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
+             The epsilon used by the batch norm layers.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         convstream_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192]`):
+             The output channels of the ConvStream module.
+         fusion_hidden_sizes (`List[int]`, *optional*, defaults to `[256, 128, 64, 32]`):
+             The output channels of the Fusion blocks.
+
+     Example:
+
+     ```python
+     >>> from transformers import VitMatteConfig, VitMatteForImageMatting
+
+     >>> # Initializing a ViTMatte hustvl/vitmatte-small-composition-1k style configuration
+     >>> configuration = VitMatteConfig()
+
+     >>> # Initializing a model (with random weights) from the hustvl/vitmatte-small-composition-1k style configuration
+     >>> model = VitMatteForImageMatting(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "vitmatte"
+
+     def __init__(
+         self,
+         backbone_config: PretrainedConfig = None,
+         backbone=None,
+         use_pretrained_backbone=False,
+         use_timm_backbone=False,
+         backbone_kwargs=None,
+         hidden_size: int = 384,
+         batch_norm_eps: float = 1e-5,
+         initializer_range: float = 0.02,
+         convstream_hidden_sizes: List[int] = [48, 96, 192],
+         fusion_hidden_sizes: List[int] = [256, 128, 64, 32],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         if backbone_config is None and backbone is None:
+             logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.")
+             backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"])
+         elif isinstance(backbone_config, dict):
+             backbone_model_type = backbone_config.get("model_type")
+             config_class = CONFIG_MAPPING[backbone_model_type]
+             backbone_config = config_class.from_dict(backbone_config)
+
+         verify_backbone_config_arguments(
+             use_timm_backbone=use_timm_backbone,
+             use_pretrained_backbone=use_pretrained_backbone,
+             backbone=backbone,
+             backbone_config=backbone_config,
+             backbone_kwargs=backbone_kwargs,
+         )
+
+         self.backbone_config = backbone_config
+         self.backbone = backbone
+         self.use_pretrained_backbone = use_pretrained_backbone
+         self.use_timm_backbone = use_timm_backbone
+         self.backbone_kwargs = backbone_kwargs
+         self.batch_norm_eps = batch_norm_eps
+         self.hidden_size = hidden_size
+         self.initializer_range = initializer_range
+         self.convstream_hidden_sizes = convstream_hidden_sizes
+         self.fusion_hidden_sizes = fusion_hidden_sizes
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
+
+         Returns:
+             `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["backbone_config"] = self.backbone_config.to_dict()
+         output["model_type"] = self.__class__.model_type
+         return output
+
+
+ __all__ = ["VitMatteConfig"]
docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py ADDED
@@ -0,0 +1,170 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert VitMatte checkpoints from the original repository.
+
+ URL: https://github.com/hustvl/ViTMatte
+ """
+
+ import argparse
+
+ import requests
+ import torch
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+
+ from transformers import VitDetConfig, VitMatteConfig, VitMatteForImageMatting, VitMatteImageProcessor
+
+
+ def get_config(model_name):
+     hidden_size = 384 if "small" in model_name else 768
+     num_attention_heads = 6 if "small" in model_name else 12
+
+     backbone_config = VitDetConfig(
+         num_channels=4,
+         image_size=512,
+         pretrain_image_size=224,
+         patch_size=16,
+         hidden_size=hidden_size,
+         num_attention_heads=num_attention_heads,
+         use_absolute_position_embeddings=True,
+         use_relative_position_embeddings=True,
+         window_size=14,
+         # 2, 5, 8, 11 for global attention
+         window_block_indices=[0, 1, 3, 4, 6, 7, 9, 10],
+         residual_block_indices=[2, 5, 8, 11],
+         out_features=["stage12"],
+     )
+
+     return VitMatteConfig(backbone_config=backbone_config, hidden_size=hidden_size)
+
+
+ # here we list all keys to be renamed (original name on the left, our name on the right)
+ def create_rename_keys(config):
+     rename_keys = []
+
+     # fmt: off
+     # stem
+     rename_keys.append(("backbone.pos_embed", "backbone.embeddings.position_embeddings"))
+     rename_keys.append(("backbone.patch_embed.proj.weight", "backbone.embeddings.projection.weight"))
+     rename_keys.append(("backbone.patch_embed.proj.bias", "backbone.embeddings.projection.bias"))
+     # fmt: on
+
+     return rename_keys
+
+
+ def rename_key(dct, old, new):
+     val = dct.pop(old)
+     dct[new] = val
+
+
+ def convert_vitmatte_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub):
+     config = get_config(model_name)
+
+     # load original state dict
+     model_name_to_filename = {
+         "vitmatte-small-composition-1k": "ViTMatte_S_Com.pth",
+         "vitmatte-base-composition-1k": "ViTMatte_B_Com.pth",
+         "vitmatte-small-distinctions-646": "ViTMatte_S_DIS.pth",
+         "vitmatte-base-distinctions-646": "ViTMatte_B_DIS.pth",
+     }
+
+     filename = model_name_to_filename[model_name]
+     filepath = hf_hub_download(repo_id="nielsr/vitmatte-checkpoints", filename=filename, repo_type="model")
+     state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
+
+     # rename keys
+     for key in state_dict.copy().keys():
+         val = state_dict.pop(key)
+         if "backbone.blocks" in key:
+             key = key.replace("backbone.blocks", "backbone.encoder.layer")
+         if "attn" in key:
+             key = key.replace("attn", "attention")
+         if "fusion_blks" in key:
+             key = key.replace("fusion_blks", "fusion_blocks")
+         if "bn" in key:
+             key = key.replace("bn", "batch_norm")
+         state_dict[key] = val
+
+     # rename keys
+     rename_keys = create_rename_keys(config)
+     for src, dest in rename_keys:
+         rename_key(state_dict, src, dest)
+
+     # create model
+     processor = VitMatteImageProcessor()
+     model = VitMatteForImageMatting(config)
+     model.eval()
+
+     # load state dict
+     model.load_state_dict(state_dict)
+
+     # verify on dummy image + trimap
+     url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_rgb.png?raw=true"
+     image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+     url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_trimap.png?raw=true"
+     trimap = Image.open(requests.get(url, stream=True).raw)
+
+     pixel_values = processor(images=image, trimaps=trimap.convert("L"), return_tensors="pt").pixel_values
+
+     with torch.no_grad():
+         alphas = model(pixel_values).alphas
+
+     if model_name == "vitmatte-small-composition-1k":
+         expected_slice = torch.tensor([[0.9977, 0.9987, 0.9990], [0.9980, 0.9998, 0.9998], [0.9983, 0.9998, 0.9998]])
+     elif model_name == "vitmatte-base-composition-1k":
+         expected_slice = torch.tensor([[0.9972, 0.9971, 0.9981], [0.9948, 0.9987, 0.9994], [0.9963, 0.9992, 0.9995]])
+     elif model_name == "vitmatte-small-distinctions-646":
+         expected_slice = torch.tensor([[0.9880, 0.9970, 0.9972], [0.9960, 0.9996, 0.9997], [0.9963, 0.9996, 0.9997]])
+     elif model_name == "vitmatte-base-distinctions-646":
+         expected_slice = torch.tensor([[0.9963, 0.9998, 0.9999], [0.9995, 1.0000, 1.0000], [0.9992, 0.9999, 1.0000]])
+
+     assert torch.allclose(alphas[0, 0, :3, :3], expected_slice, atol=1e-4)
+     print("Looks ok!")
+
+     if pytorch_dump_folder_path is not None:
+         print(f"Saving model and processor of {model_name} to {pytorch_dump_folder_path}")
+         model.save_pretrained(pytorch_dump_folder_path)
+         processor.save_pretrained(pytorch_dump_folder_path)
+
+     if push_to_hub:
+         print(f"Pushing model and processor for {model_name} to hub")
+         model.push_to_hub(f"hustvl/{model_name}")
+         processor.push_to_hub(f"hustvl/{model_name}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Required parameters
+     parser.add_argument(
+         "--model_name",
+         default="vitmatte-small-composition-1k",
+         type=str,
+         choices=[
+             "vitmatte-small-composition-1k",
+             "vitmatte-base-composition-1k",
+             "vitmatte-small-distinctions-646",
+             "vitmatte-base-distinctions-646",
+         ],
+         help="Name of the VitMatte model you'd like to convert.",
+     )
+     parser.add_argument(
+         "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+     )
+     parser.add_argument(
+         "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
+     )
+
+     args = parser.parse_args()
+     convert_vitmatte_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
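For reference, a hypothetical local run of the conversion entry point defined above (the module name assumes the script is on the Python path; downloading the original checkpoint requires network access to the Hub):

```python
# Hypothetical usage; equivalent to invoking the script with its argparse defaults.
from convert_vitmatte_to_hf import convert_vitmatte_checkpoint

convert_vitmatte_checkpoint(
    model_name="vitmatte-small-composition-1k",
    pytorch_dump_folder_path="./vitmatte-small",  # hypothetical output directory
    push_to_hub=False,
)
```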
docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py ADDED
@@ -0,0 +1,272 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Image processor class for ViTMatte."""
+
+ from typing import List, Optional, Union
+
+ import numpy as np
+
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature
+ from ...image_transforms import pad, to_channel_dimension_format
+ from ...image_utils import (
+     IMAGENET_STANDARD_MEAN,
+     IMAGENET_STANDARD_STD,
+     ChannelDimension,
+     ImageInput,
+     get_image_size,
+     infer_channel_dimension_format,
+     is_scaled_image,
+     make_list_of_images,
+     to_numpy_array,
+     valid_images,
+     validate_preprocess_arguments,
+ )
+ from ...utils import TensorType, filter_out_non_signature_kwargs, logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VitMatteImageProcessor(BaseImageProcessor):
+     r"""
+     Constructs a ViTMatte image processor.
+
+     Args:
+         do_rescale (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+             parameter in the `preprocess` method.
+         rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+             Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
+             `preprocess` method.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+             method.
+         image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+             Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+             channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+         image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+             Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+             number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+         do_pad (`bool`, *optional*, defaults to `True`):
+             Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden
+             by the `do_pad` parameter in the `preprocess` method.
+         size_divisibility (`int`, *optional*, defaults to 32):
+             The width and height of the image will be padded to be divisible by this number.
+     """
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         do_rescale: bool = True,
+         rescale_factor: Union[int, float] = 1 / 255,
+         do_normalize: bool = True,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_pad: bool = True,
+         size_divisibility: int = 32,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.do_rescale = do_rescale
+         self.do_normalize = do_normalize
+         self.do_pad = do_pad
+         self.rescale_factor = rescale_factor
+         self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+         self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+         self.size_divisibility = size_divisibility
+
+     def pad_image(
+         self,
+         image: np.ndarray,
+         size_divisibility: int = 32,
+         data_format: Optional[Union[str, ChannelDimension]] = None,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ) -> np.ndarray:
+         """
+         Args:
+             image (`np.ndarray`):
+                 Image to pad.
+             size_divisibility (`int`, *optional*, defaults to 32):
+                 The width and height of the image will be padded to be divisible by this number.
+             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         if input_data_format is None:
+             input_data_format = infer_channel_dimension_format(image)
+
+         height, width = get_image_size(image, input_data_format)
+
+         pad_height = 0 if height % size_divisibility == 0 else size_divisibility - height % size_divisibility
+         pad_width = 0 if width % size_divisibility == 0 else size_divisibility - width % size_divisibility
+         if pad_width + pad_height > 0:
+             padding = ((0, pad_height), (0, pad_width))
+             image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format)
+
+         if data_format is not None:
+             image = to_channel_dimension_format(image, data_format, input_data_format)
+
+         return image
+
+     @filter_out_non_signature_kwargs()
+     def preprocess(
+         self,
+         images: ImageInput,
+         trimaps: ImageInput,
+         do_rescale: Optional[bool] = None,
+         rescale_factor: Optional[float] = None,
+         do_normalize: Optional[bool] = None,
+         image_mean: Optional[Union[float, List[float]]] = None,
+         image_std: Optional[Union[float, List[float]]] = None,
+         do_pad: Optional[bool] = None,
+         size_divisibility: Optional[int] = None,
+         return_tensors: Optional[Union[str, TensorType]] = None,
+         data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
+         input_data_format: Optional[Union[str, ChannelDimension]] = None,
+     ):
+         """
+         Preprocess an image or batch of images.
+
+         Args:
+             images (`ImageInput`):
+                 Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
+                 passing in images with pixel values between 0 and 1, set `do_rescale=False`.
+             trimaps (`ImageInput`):
+                 Trimap to preprocess.
+             do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                 Whether to rescale the image values to the [0, 1] range.
+             rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                 Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+             do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                 Whether to normalize the image.
+             image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                 Image mean to use if `do_normalize` is set to `True`.
+             image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                 Image standard deviation to use if `do_normalize` is set to `True`.
+             do_pad (`bool`, *optional*, defaults to `self.do_pad`):
+                 Whether to pad the image.
+             size_divisibility (`int`, *optional*, defaults to `self.size_divisibility`):
+                 The size divisibility to pad the image to if `do_pad` is set to `True`.
+             return_tensors (`str` or `TensorType`, *optional*):
+                 The type of tensors to return. Can be one of:
+                 - Unset: Return a list of `np.ndarray`.
+                 - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                 - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+                 - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
+             data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+                 The channel dimension format for the output image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - Unset: Use the channel dimension format of the input image.
+             input_data_format (`ChannelDimension` or `str`, *optional*):
+                 The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                 from the input image. Can be one of:
+                 - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                 - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+                 - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
+         """
+         do_rescale = do_rescale if do_rescale is not None else self.do_rescale
+         do_normalize = do_normalize if do_normalize is not None else self.do_normalize
+         do_pad = do_pad if do_pad is not None else self.do_pad
+         rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
+         image_mean = image_mean if image_mean is not None else self.image_mean
+         image_std = image_std if image_std is not None else self.image_std
+         size_divisibility = size_divisibility if size_divisibility is not None else self.size_divisibility
+
+         images = make_list_of_images(images)
+         trimaps = make_list_of_images(trimaps, expected_ndims=2)
+
+         if not valid_images(trimaps):
+             raise ValueError(
+                 "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                 "torch.Tensor, tf.Tensor or jax.ndarray."
+             )
+
+         if not valid_images(images):
+             raise ValueError(
+                 "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                 "torch.Tensor, tf.Tensor or jax.ndarray."
+             )
+         validate_preprocess_arguments(
+             do_rescale=do_rescale,
+             rescale_factor=rescale_factor,
+             do_normalize=do_normalize,
+             image_mean=image_mean,
+             image_std=image_std,
+             do_pad=do_pad,
+             size_divisibility=size_divisibility,
+         )
+
+         # All transformations expect numpy arrays.
+         images = [to_numpy_array(image) for image in images]
+         trimaps = [to_numpy_array(trimap) for trimap in trimaps]
+
+         if do_rescale and is_scaled_image(images[0]):
+             logger.warning_once(
+                 "It looks like you are trying to rescale already rescaled images. If the input"
+                 " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
+             )
+
+         if input_data_format is None:
+             # We assume that all images have the same channel dimension format.
+             input_data_format = infer_channel_dimension_format(images[0])
+
+         if do_rescale:
+             images = [
+                 self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
+                 for image in images
+             ]
+             trimaps = [
+                 self.rescale(image=trimap, scale=rescale_factor, input_data_format=input_data_format)
+                 for trimap in trimaps
+             ]
+
+         if do_normalize:
+             images = [
+                 self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+                 for image in images
+             ]
+
+         # concatenate images and trimaps
+         images = [
+             np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps)
+         ]
+
+         if do_pad:
+             images = [
+                 self.pad_image(image, size_divisibility=size_divisibility, input_data_format=input_data_format)
+                 for image in images
+             ]
+
+         images = [
+             to_channel_dimension_format(image=image, channel_dim=data_format, input_channel_dim=input_data_format)
+             for image in images
+         ]
+
+         data = {"pixel_values": images}
+         return BatchFeature(data=data, tensor_type=return_tensors)
+
+
+ __all__ = ["VitMatteImageProcessor"]
docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py ADDED
@@ -0,0 +1,341 @@
+ # coding=utf-8
+ # Copyright 2023 HUST-VL and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """PyTorch ViTMatte model."""
+
+ from dataclasses import dataclass
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from ...modeling_utils import PreTrainedModel
+ from ...utils import (
+     ModelOutput,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     replace_return_docstrings,
+ )
+ from ...utils.backbone_utils import load_backbone
+ from .configuration_vitmatte import VitMatteConfig
+
+
+ # General docstring
+ _CONFIG_FOR_DOC = "VitMatteConfig"
+
+
+ @dataclass
+ class ImageMattingOutput(ModelOutput):
+     """
+     Class for outputs of image matting models.
+
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Loss.
+         alphas (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Estimated alpha values.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+             (also called feature maps) of the model at the output of each stage.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+             sequence_length)`.
+
+             Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+             heads.
+     """
+
+     loss: Optional[torch.FloatTensor] = None
+     alphas: Optional[torch.FloatTensor] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class VitMattePreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = VitMatteConfig
+     main_input_name = "pixel_values"
+     supports_gradient_checkpointing = True
+     _no_split_modules = []
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Conv2d):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+
+
+ class VitMatteBasicConv3x3(nn.Module):
+     """
+     Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+     """
+
+     def __init__(self, config, in_channels, out_channels, stride=2, padding=1):
+         super().__init__()
+         self.conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=3,
+             stride=stride,
+             padding=padding,
+             bias=False,
+         )
+         self.batch_norm = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
+         self.relu = nn.ReLU()
+
+     def forward(self, hidden_state):
+         hidden_state = self.conv(hidden_state)
+         hidden_state = self.batch_norm(hidden_state)
+         hidden_state = self.relu(hidden_state)
+
+         return hidden_state
+
+
+ class VitMatteConvStream(nn.Module):
+     """
+     Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+
+         # We use a default in case there isn't a backbone config set. This is for backwards compatibility and
+         # to enable loading HF backbone models.
+         in_channels = 4
+         if config.backbone_config is not None:
+             in_channels = config.backbone_config.num_channels
+
+         out_channels = config.convstream_hidden_sizes
+
+         self.convs = nn.ModuleList()
+         self.conv_chans = [in_channels] + out_channels
+
+         for i in range(len(self.conv_chans) - 1):
+             in_chan_ = self.conv_chans[i]
+             out_chan_ = self.conv_chans[i + 1]
+             self.convs.append(VitMatteBasicConv3x3(config, in_chan_, out_chan_))
+
+     def forward(self, pixel_values):
+         out_dict = {"detailed_feature_map_0": pixel_values}
+         embeddings = pixel_values
+         for i in range(len(self.convs)):
+             embeddings = self.convs[i](embeddings)
+             name_ = "detailed_feature_map_" + str(i + 1)
+             out_dict[name_] = embeddings
+
+         return out_dict
+
+
+ class VitMatteFusionBlock(nn.Module):
+     """
+     Simple fusion block to fuse features from ConvStream and Plain Vision Transformer.
+     """
+
+     def __init__(self, config, in_channels, out_channels):
+         super().__init__()
+         self.conv = VitMatteBasicConv3x3(config, in_channels, out_channels, stride=1, padding=1)
+
+     def forward(self, features, detailed_feature_map):
+         upscaled_features = nn.functional.interpolate(features, scale_factor=2, mode="bilinear", align_corners=False)
+         out = torch.cat([detailed_feature_map, upscaled_features], dim=1)
+         out = self.conv(out)
+
+         return out
+
+
+ class VitMatteHead(nn.Module):
+     """
+     Simple Matting Head, containing only conv3x3 and conv1x1 layers.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+
+         in_channels = config.fusion_hidden_sizes[-1]
+         mid_channels = 16
+
+         self.matting_convs = nn.Sequential(
+             nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(mid_channels),
+             nn.ReLU(True),
+             nn.Conv2d(mid_channels, 1, kernel_size=1, stride=1, padding=0),
+         )
+
+     def forward(self, hidden_state):
+         hidden_state = self.matting_convs(hidden_state)
+
+         return hidden_state
+
+
+ class VitMatteDetailCaptureModule(nn.Module):
+     """
+     Simple and lightweight Detail Capture Module for ViT Matting.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         if len(config.fusion_hidden_sizes) != len(config.convstream_hidden_sizes) + 1:
+             raise ValueError(
+                 "The length of fusion_hidden_sizes should be equal to the length of convstream_hidden_sizes + 1."
+             )
+
+         self.config = config
+         self.convstream = VitMatteConvStream(config)
+         self.conv_chans = self.convstream.conv_chans
+
+         self.fusion_blocks = nn.ModuleList()
+         self.fusion_channels = [config.hidden_size] + config.fusion_hidden_sizes
+
+         for i in range(len(self.fusion_channels) - 1):
+             self.fusion_blocks.append(
+                 VitMatteFusionBlock(
+                     config=config,
+                     in_channels=self.fusion_channels[i] + self.conv_chans[-(i + 1)],
+                     out_channels=self.fusion_channels[i + 1],
+                 )
+             )
+
+         self.matting_head = VitMatteHead(config)
+
+     def forward(self, features, pixel_values):
+         detail_features = self.convstream(pixel_values)
+         for i in range(len(self.fusion_blocks)):
+             detailed_feature_map_name = "detailed_feature_map_" + str(len(self.fusion_blocks) - i - 1)
+             features = self.fusion_blocks[i](features, detail_features[detailed_feature_map_name])
+
+         alphas = torch.sigmoid(self.matting_head(features))
+
+         return alphas
+
+
+ VITMATTE_START_DOCSTRING = r"""
+     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+     it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+     behavior.
+
+     Parameters:
+         config ([`VitMatteConfig`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ VITMATTE_INPUTS_DOCSTRING = r"""
+     Args:
+         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+             Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+             [`AutoImageProcessor`]. See [`VitMatteImageProcessor.__call__`] for details.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See
+             `attentions` under returned tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under
+             returned tensors for more detail.
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+
+ @add_start_docstrings(
+     """ViTMatte framework leveraging any vision backbone, e.g. ViTDet.""",
+     VITMATTE_START_DOCSTRING,
+ )
+ class VitMatteForImageMatting(VitMattePreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         self.backbone = load_backbone(config)
+         self.decoder = VitMatteDetailCaptureModule(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(VITMATTE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+     @replace_return_docstrings(output_type=ImageMattingOutput, config_class=_CONFIG_FOR_DOC)
+     def forward(
+         self,
+         pixel_values: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         """
+         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+             Ground truth image matting for computing the loss.
+
+         Returns:
+
+         Examples:
+
+         ```python
+         >>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting
+         >>> import torch
+         >>> from PIL import Image
+         >>> from huggingface_hub import hf_hub_download
+
+         >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
+         >>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")
+
+         >>> filepath = hf_hub_download(
+         ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
+         ... )
+         >>> image = Image.open(filepath).convert("RGB")
+         >>> filepath = hf_hub_download(
+         ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
+         ... )
+         >>> trimap = Image.open(filepath).convert("L")
+
+         >>> # prepare image + trimap for the model
+         >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
+
+         >>> with torch.no_grad():
+         ...     alphas = model(**inputs).alphas
+         >>> print(alphas.shape)
+         torch.Size([1, 1, 640, 960])
+         ```"""
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+         loss = None
+         if labels is not None:
+             raise NotImplementedError("Training is not yet supported")
+
+         outputs = self.backbone.forward_with_filtered_kwargs(
+             pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
+         )
+
+         features = outputs.feature_maps[-1]
+         alphas = self.decoder(features, pixel_values)
+
+         if not return_dict:
+             output = (alphas,) + outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return ImageMattingOutput(
+             loss=loss,
+             alphas=alphas,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ __all__ = ["VitMattePreTrainedModel", "VitMatteForImageMatting"]
docs/transformers/build/lib/transformers/models/vitpose/__init__.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import TYPE_CHECKING
+
+ from ...utils import _LazyModule
+ from ...utils.import_utils import define_import_structure
+
+
+ if TYPE_CHECKING:
+     from .configuration_vitpose import *
+     from .image_processing_vitpose import *
+     from .modeling_vitpose import *
+ else:
+     import sys
+
+     _file = globals()["__file__"]
+     sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py ADDED
@@ -0,0 +1,126 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """VitPose model configuration"""
+
+ from typing import Optional
+
+ from ...configuration_utils import PretrainedConfig
+ from ...utils import logging
+ from ...utils.backbone_utils import verify_backbone_config_arguments
+ from ..auto.configuration_auto import CONFIG_MAPPING
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class VitPoseConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`VitPoseForPoseEstimation`]. It is used to instantiate a
+     VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the VitPose
+     [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`):
+             The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported.
+         backbone (`str`, *optional*):
+             Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+             will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to use pretrained weights for the backbone.
+         use_timm_backbone (`bool`, *optional*, defaults to `False`):
+             Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+             library.
+         backbone_kwargs (`dict`, *optional*):
+             Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
+             e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         scale_factor (`int`, *optional*, defaults to 4):
+             Factor to upscale the feature maps coming from the ViT backbone.
+         use_simple_decoder (`bool`, *optional*, defaults to `True`):
+             Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`.
+
+     Example:
+
+     ```python
+     >>> from transformers import VitPoseConfig, VitPoseForPoseEstimation
+
+     >>> # Initializing a VitPose configuration
+     >>> configuration = VitPoseConfig()
+
+     >>> # Initializing a model (with random weights) from the configuration
+     >>> model = VitPoseForPoseEstimation(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "vitpose"
+
+     def __init__(
+         self,
+         backbone_config: Optional[PretrainedConfig] = None,
+         backbone: Optional[str] = None,
+         use_pretrained_backbone: bool = False,
+         use_timm_backbone: bool = False,
+         backbone_kwargs: Optional[dict] = None,
+         initializer_range: float = 0.02,
+         scale_factor: int = 4,
+         use_simple_decoder: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         if use_pretrained_backbone:
+             logger.info(
+                 "`use_pretrained_backbone` is `True`. Do not set this value if you only intend to run inference with pretrained VitPose weights."
+             )
+         if use_timm_backbone:
+             raise ValueError("Setting `use_timm_backbone` to `True` is not supported at the moment.")
+
+         if backbone_config is None and backbone is None:
+             logger.info("`backbone_config` is `None`. Initializing the config with the default `VitPose` backbone.")
+             backbone_config = CONFIG_MAPPING["vitpose_backbone"](out_indices=[4])
+         elif isinstance(backbone_config, dict):
+             backbone_model_type = backbone_config.get("model_type")
+             config_class = CONFIG_MAPPING[backbone_model_type]
+             backbone_config = config_class.from_dict(backbone_config)
+
+         verify_backbone_config_arguments(
+             use_timm_backbone=use_timm_backbone,
+             use_pretrained_backbone=use_pretrained_backbone,
+             backbone=backbone,
+             backbone_config=backbone_config,
+             backbone_kwargs=backbone_kwargs,
+         )
+
+         self.backbone_config = backbone_config
+         self.backbone = backbone
+         self.use_pretrained_backbone = use_pretrained_backbone
+         self.use_timm_backbone = use_timm_backbone
+         self.backbone_kwargs = backbone_kwargs
+
+         self.initializer_range = initializer_range
+         self.scale_factor = scale_factor
+         self.use_simple_decoder = use_simple_decoder
+
+
+ __all__ = ["VitPoseConfig"]
docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py ADDED
@@ -0,0 +1,428 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert VitPose checkpoints from the original repository.
16
+
17
+ URL: https://github.com/vitae-transformer/vitpose
18
+
19
+ Notebook to get the original logits: https://colab.research.google.com/drive/1QDX_2POTpl6JaZAV2WIFjuiqDsDwiqMZ?usp=sharing.
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ import re
25
+
26
+ import requests
27
+ import torch
28
+ from huggingface_hub import hf_hub_download
29
+ from PIL import Image
30
+
31
+ from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor
32
+
33
+
34
+ ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
35
+ r"patch_embed.proj": "embeddings.patch_embeddings.projection",
36
+ r"pos_embed": "embeddings.position_embeddings",
37
+ r"blocks": "encoder.layer",
38
+ r"attn.proj": "attention.output.dense",
39
+ r"attn": "attention.self",
40
+ r"norm1": "layernorm_before",
41
+ r"norm2": "layernorm_after",
42
+ r"last_norm": "layernorm",
43
+ r"keypoint_head": "head",
44
+ r"final_layer": "conv",
45
+ }
46
+
47
+ MODEL_TO_FILE_NAME_MAPPING = {
48
+ # VitPose models, simple decoder
49
+ "vitpose-base-simple": "vitpose-b-simple.pth",
50
+ # VitPose models, classic decoder
51
+ "vitpose-base": "vitpose-b.pth",
52
+ # VitPose models, COCO-AIC-MPII
53
+ "vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth",
54
+ # VitPose+ models
55
+ "vitpose-plus-small": "vitpose+_small.pth",
56
+ "vitpose-plus-base": "vitpose+_base.pth",
57
+ "vitpose-plus-large": "vitpose+_large.pth",
58
+ "vitpose-plus-huge": "vitpose+_huge.pth",
59
+ }
60
+
61
+
62
+ def get_config(model_name):
63
+ if "plus" in model_name:
64
+ num_experts = 6
65
+ if "small" in model_name:
66
+ part_features = 96
67
+ out_indices = [12]
68
+ elif "base" in model_name:
69
+ part_features = 192
70
+ out_indices = [12]
71
+ elif "large" in model_name:
72
+ part_features = 256
73
+ out_indices = [24]
74
+ elif "huge" in model_name:
75
+ part_features = 320
76
+ out_indices = [32]
77
+ else:
78
+ raise ValueError(f"Model {model_name} not supported")
79
+ else:
80
+ num_experts = 1
81
+ part_features = 0
82
+
83
+ # size of the architecture
84
+ if "small" in model_name:
85
+ hidden_size = 384
86
+ num_hidden_layers = 12
87
+ num_attention_heads = 12
88
+ elif "large" in model_name:
89
+ hidden_size = 1024
90
+ num_hidden_layers = 24
91
+ num_attention_heads = 16
92
+ elif "huge" in model_name:
93
+ hidden_size = 1280
94
+ num_hidden_layers = 32
95
+ num_attention_heads = 16
96
+
97
+ backbone_config = VitPoseBackboneConfig(
98
+ out_indices=out_indices,
99
+ hidden_size=hidden_size,
100
+ num_hidden_layers=num_hidden_layers,
101
+ num_attention_heads=num_attention_heads,
102
+ num_experts=num_experts,
103
+ part_features=part_features,
104
+ )
105
+
106
+ use_simple_decoder = "simple" in model_name
107
+
108
+ edges = [
109
+ [15, 13],
110
+ [13, 11],
111
+ [16, 14],
112
+ [14, 12],
113
+ [11, 12],
114
+ [5, 11],
115
+ [6, 12],
116
+ [5, 6],
117
+ [5, 7],
118
+ [6, 8],
119
+ [7, 9],
120
+ [8, 10],
121
+ [1, 2],
122
+ [0, 1],
123
+ [0, 2],
124
+ [1, 3],
125
+ [2, 4],
126
+ [3, 5],
127
+ [4, 6],
128
+ ]
129
+ id2label = {
130
+ 0: "Nose",
131
+ 1: "L_Eye",
132
+ 2: "R_Eye",
133
+ 3: "L_Ear",
134
+ 4: "R_Ear",
135
+ 5: "L_Shoulder",
136
+ 6: "R_Shoulder",
137
+ 7: "L_Elbow",
138
+ 8: "R_Elbow",
139
+ 9: "L_Wrist",
140
+ 10: "R_Wrist",
141
+ 11: "L_Hip",
142
+ 12: "R_Hip",
143
+ 13: "L_Knee",
144
+ 14: "R_Knee",
145
+ 15: "L_Ankle",
146
+ 16: "R_Ankle",
147
+ }
148
+
149
+ label2id = {v: k for k, v in id2label.items()}
150
+
151
+ config = VitPoseConfig(
152
+ backbone_config=backbone_config,
153
+ num_labels=17,
154
+ use_simple_decoder=use_simple_decoder,
155
+ edges=edges,
156
+ id2label=id2label,
157
+ label2id=label2id,
158
+ )
159
+
160
+ return config
161
+
162
+
163
+ def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
164
+ """
165
+ This function should be applied only once, on the concatenated keys to efficiently rename using
166
+ the key mappings.
167
+ """
168
+ output_dict = {}
169
+ if state_dict_keys is not None:
170
+ old_text = "\n".join(state_dict_keys)
171
+ new_text = old_text
172
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
173
+ if replacement is None:
174
+ new_text = re.sub(pattern, "", new_text) # an empty line
175
+ continue
176
+ new_text = re.sub(pattern, replacement, new_text)
177
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
178
+ return output_dict
179
+
180
+
181
+ # We will verify our results on a COCO image
182
+ def prepare_img():
183
+ url = "http://images.cocodataset.org/val2017/000000000139.jpg"
184
+ image = Image.open(requests.get(url, stream=True).raw)
185
+ return image
186
+
187
+
188
+ @torch.no_grad()
189
+ def write_model(model_name, model_path, push_to_hub, check_logits=True):
190
+ # ------------------------------------------------------------
191
+ # Vision model params and config
192
+ # ------------------------------------------------------------
193
+
194
+ # params from config
195
+ config = get_config(model_name)
196
+
197
+ # ------------------------------------------------------------
198
+ # Convert weights
199
+ # ------------------------------------------------------------
200
+
201
+ # load original state_dict
202
+ filename = MODEL_TO_FILE_NAME_MAPPING[model_name]
203
+ print(f"Fetching all parameters from the checkpoint at {filename}...")
204
+
205
+ checkpoint_path = hf_hub_download(
206
+ repo_id="nielsr/vitpose-original-checkpoints", filename=filename, repo_type="model"
207
+ )
208
+
209
+ print("Converting model...")
210
+ original_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"]
211
+ all_keys = list(original_state_dict.keys())
212
+ new_keys = convert_old_keys_to_new_keys(all_keys)
213
+
214
+ dim = config.backbone_config.hidden_size
215
+
216
+ state_dict = {}
217
+ for key in all_keys:
218
+ new_key = new_keys[key]
219
+ value = original_state_dict[key]
220
+
221
+ if re.search("associate_heads", new_key) or re.search("backbone.cls_token", new_key):
222
+ # The associate_heads weights belong to an auxiliary head, so they are not needed at inference time.
224
+ # backbone.cls_token is only used by an optional forward path that supports dynamic input sizes, see https://github.com/ViTAE-Transformer/ViTPose/issues/34 for details.
224
+ pass
225
+ elif re.search("qkv", new_key):
226
+ state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
227
+ state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
228
+ state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
229
+ elif re.search("head", new_key) and not config.use_simple_decoder:
230
+ # Pattern for deconvolution layers
231
+ deconv_pattern = r"deconv_layers\.(0|3)\.weight"
232
+ new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1)) // 3 + 1}.weight", new_key)
233
+ # Pattern for batch normalization layers
234
+ bn_patterns = [
235
+ (r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
236
+ (r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
237
+ (r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
238
+ (r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
239
+ (r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
240
+ ]
241
+
242
+ for pattern, replacement in bn_patterns:
243
+ if re.search(pattern, new_key):
244
+ # Convert the layer number to the correct batch norm index
245
+ layer_num = int(re.search(pattern, key).group(1))
246
+ bn_num = layer_num // 3 + 1
247
+ new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
248
+ state_dict[new_key] = value
249
+ else:
250
+ state_dict[new_key] = value
251
+
252
+ print("Loading the checkpoint in a Vitpose model.")
253
+ model = VitPoseForPoseEstimation(config)
254
+ model.eval()
255
+ model.load_state_dict(state_dict)
256
+ print("Checkpoint loaded successfully.")
257
+
258
+ # create image processor
259
+ image_processor = VitPoseImageProcessor()
260
+
261
+ # verify image processor
262
+ image = prepare_img()
263
+ boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
264
+ pixel_values = image_processor(images=image, boxes=boxes, return_tensors="pt").pixel_values
265
+
266
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
267
+ original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)["img"]
268
+ # we allow for a small difference in the pixel values due to the original repository using cv2
269
+ assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)
270
+
271
+ dataset_index = torch.tensor([0])
272
+
273
+ with torch.no_grad():
274
+ print("Shape of original_pixel_values: ", original_pixel_values.shape)
275
+ print("First values of original_pixel_values: ", original_pixel_values[0, 0, :3, :3])
276
+
277
+ # first forward pass
278
+ outputs = model(original_pixel_values, dataset_index=dataset_index)
279
+ output_heatmap = outputs.heatmaps
280
+
281
+ print("Shape of output_heatmap: ", output_heatmap.shape)
282
+ print("First values: ", output_heatmap[0, 0, :3, :3])
283
+
284
+ # second forward pass (flipped)
285
+ # this is done since the model uses `flip_test=True` in its test config
286
+ original_pixel_values_flipped = torch.flip(original_pixel_values, [3])
287
+ outputs_flipped = model(
288
+ original_pixel_values_flipped,
289
+ dataset_index=dataset_index,
290
+ flip_pairs=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]),
291
+ )
292
+ output_flipped_heatmap = outputs_flipped.heatmaps
293
+
294
+ outputs.heatmaps = (output_heatmap + output_flipped_heatmap) * 0.5
295
+
296
+ # Verify pose_results
297
+ pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
298
+
299
+ if check_logits:
300
+ # Simple decoder checkpoints
301
+ if model_name == "vitpose-base-simple":
302
+ assert torch.allclose(
303
+ pose_results[1]["keypoints"][0],
304
+ torch.tensor([3.98180511e02, 1.81808380e02]),
305
+ atol=5e-2,
306
+ )
307
+ assert torch.allclose(
308
+ pose_results[1]["scores"][0],
309
+ torch.tensor([8.66642594e-01]),
310
+ atol=5e-2,
311
+ )
312
+ # Classic decoder checkpoints
313
+ elif model_name == "vitpose-base":
314
+ assert torch.allclose(
315
+ pose_results[1]["keypoints"][0],
316
+ torch.tensor([3.9807913e02, 1.8182812e02]),
317
+ atol=5e-2,
318
+ )
319
+ assert torch.allclose(
320
+ pose_results[1]["scores"][0],
321
+ torch.tensor([8.8235235e-01]),
322
+ atol=5e-2,
323
+ )
324
+ # COCO-AIC-MPII checkpoints
325
+ elif model_name == "vitpose-base-coco-aic-mpii":
326
+ assert torch.allclose(
327
+ pose_results[1]["keypoints"][0],
328
+ torch.tensor([3.98305542e02, 1.81741592e02]),
329
+ atol=5e-2,
330
+ )
331
+ assert torch.allclose(
332
+ pose_results[1]["scores"][0],
333
+ torch.tensor([8.69966745e-01]),
334
+ atol=5e-2,
335
+ )
336
+ # VitPose+ models
337
+ elif model_name == "vitpose-plus-small":
338
+ assert torch.allclose(
339
+ pose_results[1]["keypoints"][0],
340
+ torch.tensor([398.1597, 181.6902]),
341
+ atol=5e-2,
342
+ )
343
+ assert torch.allclose(
344
+ pose_results[1]["scores"][0],
345
+ torch.tensor(0.9051),
346
+ atol=5e-2,
347
+ )
348
+ elif model_name == "vitpose-plus-base":
349
+ assert torch.allclose(
350
+ pose_results[1]["keypoints"][0],
351
+ torch.tensor([3.98201294e02, 1.81728302e02]),
352
+ atol=5e-2,
353
+ )
354
+ assert torch.allclose(
355
+ pose_results[1]["scores"][0],
356
+ torch.tensor([8.75046968e-01]),
357
+ atol=5e-2,
358
+ )
359
+ elif model_name == "vitpose-plus-large":
360
+ assert torch.allclose(
361
+ pose_results[1]["keypoints"][0],
362
+ torch.tensor([398.1409, 181.7412]),
363
+ atol=5e-2,
364
+ )
365
+ assert torch.allclose(
366
+ pose_results[1]["scores"][0],
367
+ torch.tensor(0.8746),
368
+ atol=5e-2,
369
+ )
370
+ elif model_name == "vitpose-plus-huge":
371
+ assert torch.allclose(
372
+ pose_results[1]["keypoints"][0],
373
+ torch.tensor([398.2079, 181.8026]),
374
+ atol=5e-2,
375
+ )
376
+ assert torch.allclose(
377
+ pose_results[1]["scores"][0],
378
+ torch.tensor(0.8693),
379
+ atol=5e-2,
380
+ )
381
+ else:
382
+ raise ValueError("Model not supported")
383
+ print("Conversion successfully done.")
384
+
385
+ if model_path is not None:
386
+ os.makedirs(model_path, exist_ok=True)
387
+ model.save_pretrained(model_path)
388
+ image_processor.save_pretrained(model_path)
389
+
390
+ if push_to_hub:
391
+ print(f"Pushing model and image processor for {model_name} to hub")
392
+ # we created a community organization on the hub for this model
393
+ # maintained by the Transformers team
394
+ model.push_to_hub(f"usyd-community/{model_name}")
395
+ image_processor.push_to_hub(f"usyd-community/{model_name}")
396
+
397
+
398
+ def main():
399
+ parser = argparse.ArgumentParser()
400
+ # Required parameters
401
+ parser.add_argument(
402
+ "--model_name",
403
+ default="vitpose-base-simple",
404
+ choices=MODEL_TO_FILE_NAME_MAPPING.keys(),
405
+ type=str,
406
+ help="Name of the VitPose model you'd like to convert.",
407
+ )
408
+ parser.add_argument(
409
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to store the converted model."
410
+ )
411
+ parser.add_argument(
412
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
413
+ )
414
+ parser.add_argument(
415
+ "--check_logits", action="store_false", help="Whether or not to verify the logits of the converted model."
416
+ )
417
+
418
+ args = parser.parse_args()
419
+ write_model(
420
+ model_path=args.pytorch_dump_folder_path,
421
+ model_name=args.model_name,
422
+ push_to_hub=args.push_to_hub,
423
+ check_logits=args.check_logits,
424
+ )
425
+
426
+
427
+ if __name__ == "__main__":
428
+ main()
docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py ADDED
@@ -0,0 +1,684 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for VitPose."""
16
+
17
+ import itertools
18
+ import math
19
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+
23
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from ...image_transforms import to_channel_dimension_format
25
+ from ...image_utils import (
26
+ IMAGENET_DEFAULT_MEAN,
27
+ IMAGENET_DEFAULT_STD,
28
+ ChannelDimension,
29
+ ImageInput,
30
+ infer_channel_dimension_format,
31
+ is_scaled_image,
32
+ make_list_of_images,
33
+ to_numpy_array,
34
+ valid_images,
35
+ )
36
+ from ...utils import TensorType, is_scipy_available, is_torch_available, is_vision_available, logging
37
+
38
+
39
+ if is_torch_available():
40
+ import torch
41
+
42
+ if is_vision_available():
43
+ import PIL
44
+
45
+ if is_scipy_available():
46
+ from scipy.linalg import inv
47
+ from scipy.ndimage import affine_transform, gaussian_filter
48
+
49
+ if TYPE_CHECKING:
50
+ from .modeling_vitpose import VitPoseEstimatorOutput
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ # inspired by https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py#L132
56
+ def box_to_center_and_scale(
57
+ box: Union[Tuple, List, np.ndarray],
58
+ image_width: int,
59
+ image_height: int,
60
+ normalize_factor: float = 200.0,
61
+ padding_factor: float = 1.25,
62
+ ):
63
+ """
64
+ Encodes a bounding box in COCO format into (center, scale).
65
+
66
+ Args:
67
+ box (`Tuple`, `List`, or `np.ndarray`):
68
+ Bounding box in COCO format (top_left_x, top_left_y, width, height).
69
+ image_width (`int`):
70
+ Image width.
71
+ image_height (`int`):
72
+ Image height.
73
+ normalize_factor (`float`):
74
+ Width and height scale factor.
75
+ padding_factor (`float`):
76
+ Bounding box padding factor.
77
+
78
+ Returns:
79
+ tuple: A tuple containing center and scale.
80
+
81
+ - `np.ndarray` [float32](2,): Center of the bbox (x, y).
82
+ - `np.ndarray` [float32](2,): Scale of the bbox width & height.
83
+ """
84
+
85
+ top_left_x, top_left_y, width, height = box[:4]
86
+ aspect_ratio = image_width / image_height
87
+ center = np.array([top_left_x + width * 0.5, top_left_y + height * 0.5], dtype=np.float32)
88
+
89
+ if width > aspect_ratio * height:
90
+ height = width * 1.0 / aspect_ratio
91
+ elif width < aspect_ratio * height:
92
+ width = height * aspect_ratio
93
+
94
+ scale = np.array([width / normalize_factor, height / normalize_factor], dtype=np.float32)
95
+ scale = scale * padding_factor
96
+
97
+ return center, scale
98
+
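A minimal numeric sketch of the encoding (the box is one of the COCO boxes used elsewhere in this model's tests; outputs are rounded):

```python
import numpy as np  # already imported at the top of this module

# A COCO box (x, y, w, h) for a 192x256 model input (aspect ratio 0.75).
# The box is widened to match the aspect ratio, normalized by 200 px,
# then padded by the default factor of 1.25:
center, scale = box_to_center_and_scale(
    [412.8, 157.61, 53.05, 138.01], image_width=192, image_height=256
)
print(center)  # [439.325 226.615] -> midpoint of the original box
print(scale)   # ~[0.647 0.863]    -> (138.01 * 0.75, 138.01) * 1.25 / 200
```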
99
+
100
+ def coco_to_pascal_voc(bboxes: np.ndarray) -> np.ndarray:
101
+ """
102
+ Converts bounding boxes from the COCO format to the Pascal VOC format.
103
+
104
+ In other words, converts from (top_left_x, top_left_y, width, height) format
105
+ to (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
106
+
107
+ Args:
108
+ bboxes (`np.ndarray` of shape `(batch_size, 4)`):
109
+ Bounding boxes in COCO format.
110
+
111
+ Returns:
112
+ `np.ndarray` of shape `(batch_size, 4)` in Pascal VOC format.
113
+ """
114
+ bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0] - 1
115
+ bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1] - 1
116
+
117
+ return bboxes
118
+
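For example, a 10x20 box with its top-left corner at (5, 5); the -1 in the formulas above makes the bottom-right coordinates inclusive:

```python
import numpy as np

boxes = np.array([[5.0, 5.0, 10.0, 20.0]])  # (x, y, w, h)
print(coco_to_pascal_voc(boxes))  # [[ 5.  5. 14. 24.]] -> (x1, y1, x2, y2)
```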
119
+
120
+ def get_keypoint_predictions(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
121
+ """Get keypoint predictions from score maps.
122
+
123
+ Args:
124
+ heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
125
+ Model predicted heatmaps.
126
+
127
+ Returns:
128
+ tuple: A tuple containing aggregated results.
129
+
130
+ - coords (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
131
+ Predicted keypoint location.
132
+ - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
133
+ Scores (confidence) of the keypoints.
134
+ """
135
+ if not isinstance(heatmaps, np.ndarray):
136
+ raise ValueError("Heatmaps should be np.ndarray")
137
+ if heatmaps.ndim != 4:
138
+ raise ValueError("Heatmaps should be 4-dimensional")
139
+
140
+ batch_size, num_keypoints, _, width = heatmaps.shape
141
+ heatmaps_reshaped = heatmaps.reshape((batch_size, num_keypoints, -1))
142
+ idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
143
+ scores = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1))
144
+
145
+ preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
146
+ preds[:, :, 0] = preds[:, :, 0] % width
147
+ preds[:, :, 1] = preds[:, :, 1] // width
148
+
149
+ preds = np.where(np.tile(scores, (1, 1, 2)) > 0.0, preds, -1)
150
+ return preds, scores
151
+
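A small sketch of the argmax decoding: the flattened peak index is turned back into (x, y) via modulo and integer division by the heatmap width:

```python
import numpy as np

heatmaps = np.zeros((1, 1, 4, 4), dtype=np.float32)
heatmaps[0, 0, 1, 2] = 0.9  # peak at row y=1, column x=2
coords, scores = get_keypoint_predictions(heatmaps)
print(coords)  # [[[2. 1.]]] -> flattened index 6: 6 % 4 = 2, 6 // 4 = 1
print(scores)  # [[[0.9]]]   -> the peak value is used as the confidence
```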
152
+
153
+ def post_dark_unbiased_data_processing(coords: np.ndarray, batch_heatmaps: np.ndarray, kernel: int = 3) -> np.ndarray:
154
+ """DARK post-pocessing. Implemented by unbiased_data_processing.
155
+
156
+ Paper references:
157
+ - Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
158
+ - Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020).
159
+
160
+ Args:
161
+ coords (`np.ndarray` of shape `(num_persons, num_keypoints, 2)`):
162
+ Initial coordinates of human pose.
163
+ batch_heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`):
164
+ Batched heatmaps as predicted by the model.
165
+ A batch_size of 1 is used for the bottom up paradigm where all persons share the same heatmap.
166
+ A batch_size of `num_persons` is used for the top down paradigm where each person has its own heatmaps.
167
+ kernel (`int`, *optional*, defaults to 3):
168
+ Gaussian kernel size (K) for modulation.
169
+
170
+ Returns:
171
+ `np.ndarray` of shape `(num_persons, num_keypoints, 2)` ):
172
+ Refined coordinates.
173
+ """
174
+ batch_size, num_keypoints, height, width = batch_heatmaps.shape
175
+ num_coords = coords.shape[0]
176
+ if not (batch_size == 1 or batch_size == num_coords):
177
+ raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.")
178
+ radius = int((kernel - 1) // 2)
179
+ batch_heatmaps = np.array(
180
+ [
181
+ [gaussian_filter(heatmap, sigma=0.8, radius=(radius, radius), axes=(0, 1)) for heatmap in heatmaps]
182
+ for heatmaps in batch_heatmaps
183
+ ]
184
+ )
185
+ batch_heatmaps = np.clip(batch_heatmaps, 0.001, 50)
186
+ batch_heatmaps = np.log(batch_heatmaps)
187
+
188
+ batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten()
189
+
190
+ # calculate indices for coordinates
191
+ index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (width + 2)
192
+ index += (width + 2) * (height + 2) * np.arange(0, batch_size * num_keypoints).reshape(-1, num_keypoints)
193
+ index = index.astype(int).reshape(-1, 1)
194
+ i_ = batch_heatmaps_pad[index]
195
+ ix1 = batch_heatmaps_pad[index + 1]
196
+ iy1 = batch_heatmaps_pad[index + width + 2]
197
+ ix1y1 = batch_heatmaps_pad[index + width + 3]
198
+ ix1_y1_ = batch_heatmaps_pad[index - width - 3]
199
+ ix1_ = batch_heatmaps_pad[index - 1]
200
+ iy1_ = batch_heatmaps_pad[index - 2 - width]
201
+
202
+ # calculate refined coordinates using Newton's method
203
+ dx = 0.5 * (ix1 - ix1_)
204
+ dy = 0.5 * (iy1 - iy1_)
205
+ derivative = np.concatenate([dx, dy], axis=1)
206
+ derivative = derivative.reshape(num_coords, num_keypoints, 2, 1)
207
+ dxx = ix1 - 2 * i_ + ix1_
208
+ dyy = iy1 - 2 * i_ + iy1_
209
+ dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
210
+ hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
211
+ hessian = hessian.reshape(num_coords, num_keypoints, 2, 2)
212
+ hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
213
+ coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze()
214
+ return coords
215
+
216
+
217
+ def transform_preds(coords: np.ndarray, center: np.ndarray, scale: np.ndarray, output_size: np.ndarray) -> np.ndarray:
218
+ """Get final keypoint predictions from heatmaps and apply scaling and
219
+ translation to map them back to the image.
220
+
221
+ Note:
222
+ num_keypoints: K
223
+
224
+ Args:
225
+ coords (`np.ndarray` of shape `(num_keypoints, ndims)`):
226
+
227
+ * If ndims=2, coords are the predicted keypoint locations.
228
+ * If ndims=4, coords are composed of (x, y, scores, tags)
229
+ * If ndims=5, coords are composed of (x, y, scores, tags,
230
+ flipped_tags)
231
+
232
+ center (`np.ndarray` of shape `(2,)`):
233
+ Center of the bounding box (x, y).
234
+ scale (`np.ndarray` of shape `(2,)`):
235
+ Scale of the bounding box wrt original image of width and height.
236
+ output_size (`np.ndarray` of shape `(2,)`):
237
+ Size of the destination heatmaps in (height, width) format.
238
+
239
+ Returns:
240
+ np.ndarray: Predicted coordinates in the images.
241
+ """
242
+ if coords.shape[1] not in (2, 4, 5):
243
+ raise ValueError("Coordinates need to have either 2, 4 or 5 dimensions.")
244
+ if len(center) != 2:
245
+ raise ValueError("Center needs to have 2 elements, one for x and one for y.")
246
+ if len(scale) != 2:
247
+ raise ValueError("Scale needs to consist of a width and height")
248
+ if len(output_size) != 2:
249
+ raise ValueError("Output size needs to consist of a height and width")
250
+
251
+ # Recover the scale which is normalized by a factor of 200.
252
+ scale = scale * 200.0
253
+
254
+ # We use unbiased data processing
255
+ scale_y = scale[1] / (output_size[0] - 1.0)
256
+ scale_x = scale[0] / (output_size[1] - 1.0)
257
+
258
+ target_coords = np.ones_like(coords)
259
+ target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
260
+ target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
261
+
262
+ return target_coords
263
+
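A numeric sketch of the mapping (values chosen so the arithmetic is round): with a 64x48 heatmap and a normalized scale of (0.5, 0.5), i.e. a 100x100 pixel box, the heatmap corners land on the corners of the padded box in image space:

```python
import numpy as np

coords = np.array([[0.0, 0.0], [47.0, 63.0]])  # two corners of a 64x48 heatmap
preds = transform_preds(
    coords,
    center=np.array([100.0, 120.0]),
    scale=np.array([0.5, 0.5]),  # recovered to 100x100 px inside the function
    output_size=[64, 48],
)
print(preds)  # [[ 50.  70.] [150. 170.]] -> center +/- half the box size
```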
264
+
265
+ def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray, size_target: np.ndarray):
266
+ """
267
+ Calculate the transformation matrix under the unbiased data processing constraint. Paper ref: Huang et al. The Devil is in the
268
+ Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
269
+
270
+ Source: https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
271
+
272
+ Args:
273
+ theta (`float`):
274
+ Rotation angle in degrees.
275
+ size_input (`np.ndarray`):
276
+ Size of input image [width, height].
277
+ size_dst (`np.ndarray`):
278
+ Size of output image [width, height].
279
+ size_target (`np.ndarray`):
280
+ Size of ROI in input plane [w, h].
281
+
282
+ Returns:
283
+ `np.ndarray`: A matrix for transformation.
284
+ """
285
+ theta = np.deg2rad(theta)
286
+ matrix = np.zeros((2, 3), dtype=np.float32)
287
+ scale_x = size_dst[0] / size_target[0]
288
+ scale_y = size_dst[1] / size_target[1]
289
+ matrix[0, 0] = math.cos(theta) * scale_x
290
+ matrix[0, 1] = -math.sin(theta) * scale_x
291
+ matrix[0, 2] = scale_x * (
292
+ -0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * math.sin(theta) + 0.5 * size_target[0]
293
+ )
294
+ matrix[1, 0] = math.sin(theta) * scale_y
295
+ matrix[1, 1] = math.cos(theta) * scale_y
296
+ matrix[1, 2] = scale_y * (
297
+ -0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * math.cos(theta) + 0.5 * size_target[1]
298
+ )
299
+ return matrix
300
+
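With no rotation, the matrix reduces to a pure scale plus a centering translation, which is easy to check by hand:

```python
import numpy as np

M = get_warp_matrix(
    theta=0.0,
    size_input=np.array([200.0, 200.0]),
    size_dst=np.array([191.0, 255.0]),
    size_target=np.array([100.0, 100.0]),
)
print(M[0, 0], M[1, 1])  # 1.91 2.55 -> scale_x = 191/100, scale_y = 255/100
print(M[:, 2])           # [-95.5 -127.5] -> translation centering the ROI around size_input / 2
```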
301
+
302
+ def scipy_warp_affine(src, M, size):
303
+ """
304
+ This function implements the cv2.warpAffine function using scipy's affine_transform. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html and https://docs.opencv.org/4.x/d4/d61/tutorial_warp_affine.html for more details.
305
+
306
+ Note: the original implementation of cv2.warpAffine uses cv2.INTER_LINEAR.
307
+ """
308
+ channels = [src[..., i] for i in range(src.shape[-1])]
309
+
310
+ # Convert to a 3x3 matrix used by SciPy
311
+ M_scipy = np.vstack([M, [0, 0, 1]])
312
+ # If you have a matrix for the ‘push’ transformation, use its inverse (numpy.linalg.inv) in this function.
313
+ M_inv = inv(M_scipy)
314
+ M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
315
+ M_inv[1, 1],
316
+ M_inv[1, 0],
317
+ M_inv[0, 1],
318
+ M_inv[0, 0],
319
+ M_inv[1, 2],
320
+ M_inv[0, 2],
321
+ )
322
+
323
+ new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
324
+ new_src = np.stack(new_src, axis=-1)
325
+ return new_src
326
+
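A sketch tying the two helpers together, mirroring how the `affine_transform` method below builds and applies the warp (a random channels-last image stands in for a real one):

```python
import numpy as np

image = np.random.rand(480, 640, 3).astype(np.float32)  # (H, W, C)
center = np.array([320.0, 240.0])                       # box center (x, y)
scale = np.array([1.0, 1.0])                            # i.e. a 200x200 px box
# same call pattern as affine_transform(): size is (width, height) = (192, 256)
M = get_warp_matrix(0.0, center * 2.0, np.array([192.0, 256.0]) - 1.0, scale * 200.0)
warped = scipy_warp_affine(src=image, M=M, size=(256, 192))
print(warped.shape)  # (256, 192, 3) -> (height, width, channels)
```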
327
+
328
+ class VitPoseImageProcessor(BaseImageProcessor):
329
+ r"""
330
+ Constructs a VitPose image processor.
331
+
332
+ Args:
333
+ do_affine_transform (`bool`, *optional*, defaults to `True`):
334
+ Whether to apply an affine transformation to the input images.
335
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 256, "width": 192}`):
336
+ Resolution of the image after `affine_transform` is applied. Only has an effect if `do_affine_transform` is set to `True`. Can
337
+ be overridden by `size` in the `preprocess` method.
338
+ do_rescale (`bool`, *optional*, defaults to `True`):
339
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
340
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
341
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
342
+ method.
343
+ do_normalize (`bool`, *optional*, defaults to `True`):
344
+ Whether or not to normalize the input with mean and standard deviation.
345
+ image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`, *optional*):
346
+ The sequence of means for each channel, to be used when normalizing images.
347
+ image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`, *optional*):
348
+ The sequence of standard deviations for each channel, to be used when normalizing images.
349
+ """
350
+
351
+ model_input_names = ["pixel_values"]
352
+
353
+ def __init__(
354
+ self,
355
+ do_affine_transform: bool = True,
356
+ size: Dict[str, int] = None,
357
+ do_rescale: bool = True,
358
+ rescale_factor: Union[int, float] = 1 / 255,
359
+ do_normalize: bool = True,
360
+ image_mean: Optional[Union[float, List[float]]] = None,
361
+ image_std: Optional[Union[float, List[float]]] = None,
362
+ **kwargs,
363
+ ):
364
+ super().__init__(**kwargs)
365
+ self.do_affine_transform = do_affine_transform
366
+ self.size = size if size is not None else {"height": 256, "width": 192}
367
+ self.do_rescale = do_rescale
368
+ self.rescale_factor = rescale_factor
369
+ self.do_normalize = do_normalize
370
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
371
+ self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
372
+ self.normalize_factor = 200.0
373
+
374
+ def affine_transform(
375
+ self,
376
+ image: np.array,
377
+ center: Tuple[float],
378
+ scale: Tuple[float],
379
+ rotation: float,
380
+ size: Dict[str, int],
381
+ data_format: Optional[ChannelDimension] = None,
382
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
383
+ ) -> np.array:
384
+ """
385
+ Apply an affine transformation to an image.
386
+
387
+ Args:
388
+ image (`np.array`):
389
+ Image to transform.
390
+ center (`Tuple[float]`):
391
+ Center of the bounding box (x, y).
392
+ scale (`Tuple[float]`):
393
+ Scale of the bounding box with respect to height/width.
394
+ rotation (`float`):
395
+ Rotation angle in degrees.
396
+ size (`Dict[str, int]`):
397
+ Size of the destination image.
398
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
399
+ The channel dimension format of the output image.
400
+ input_data_format (`str` or `ChannelDimension`, *optional*):
401
+ The channel dimension format of the input image.
402
+ """
403
+
404
+ data_format = input_data_format if data_format is None else data_format
405
+
406
+ size = (size["width"], size["height"])
407
+
408
+ # one uses a pixel standard deviation of 200 pixels
409
+ transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0)
410
+
411
+ # input image requires channels last format
412
+ image = (
413
+ image
414
+ if input_data_format == ChannelDimension.LAST
415
+ else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
416
+ )
417
+ image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0]))
418
+
419
+ image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST)
420
+
421
+ return image
422
+
423
+ def preprocess(
424
+ self,
425
+ images: ImageInput,
426
+ boxes: Union[List[List[float]], np.ndarray],
427
+ do_affine_transform: Optional[bool] = None,
428
+ size: Dict[str, int] = None,
429
+ do_rescale: Optional[bool] = None,
430
+ rescale_factor: Optional[float] = None,
431
+ do_normalize: Optional[bool] = None,
432
+ image_mean: Optional[Union[float, List[float]]] = None,
433
+ image_std: Optional[Union[float, List[float]]] = None,
434
+ return_tensors: Optional[Union[str, TensorType]] = None,
435
+ data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
436
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
437
+ ) -> PIL.Image.Image:
438
+ """
439
+ Preprocess an image or batch of images.
440
+
441
+ Args:
442
+ images (`ImageInput`):
443
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
444
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
445
+
446
+ boxes (`List[List[List[float]]]` or `np.ndarray`):
447
+ List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
448
+ box coordinates in COCO format (top_left_x, top_left_y, width, height).
449
+
450
+ do_affine_transform (`bool`, *optional*, defaults to `self.do_affine_transform`):
451
+ Whether to apply an affine transformation to the input images.
452
+ size (`Dict[str, int]` *optional*, defaults to `self.size`):
453
+ Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after
454
+ resizing.
455
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
456
+ Whether to rescale the image values between [0 - 1].
457
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
458
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
459
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
460
+ Whether to normalize the image.
461
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
462
+ Image mean to use if `do_normalize` is set to `True`.
463
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
464
+ Image standard deviation to use if `do_normalize` is set to `True`.
465
+ return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`):
466
+ If set, will return tensors of a particular framework. Acceptable values are:
467
+
468
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
469
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
470
+ - `'np'`: Return NumPy `np.ndarray` objects.
471
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
472
+
473
+ Returns:
474
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
475
+
476
+ - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
477
+ width).
478
+ """
479
+ do_affine_transform = do_affine_transform if do_affine_transform is not None else self.do_affine_transform
480
+ size = size if size is not None else self.size
481
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
482
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
483
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
484
+ image_mean = image_mean if image_mean is not None else self.image_mean
485
+ image_std = image_std if image_std is not None else self.image_std
486
+
487
+ images = make_list_of_images(images)
488
+
489
+ if not valid_images(images):
490
+ raise ValueError(
491
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
492
+ "torch.Tensor, tf.Tensor or jax.ndarray."
493
+ )
494
+
495
+ if isinstance(boxes, list) and len(images) != len(boxes):
496
+ raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {len(boxes)}")
497
+ elif isinstance(boxes, np.ndarray) and len(images) != boxes.shape[0]:
498
+ raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {boxes.shape[0]}")
499
+
500
+ # All transformations expect numpy arrays.
501
+ images = [to_numpy_array(image) for image in images]
502
+
503
+ if is_scaled_image(images[0]) and do_rescale:
504
+ logger.warning_once(
505
+ "It looks like you are trying to rescale already rescaled images. If the input"
506
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
507
+ )
508
+
509
+ if input_data_format is None:
510
+ # We assume that all images have the same channel dimension format.
511
+ input_data_format = infer_channel_dimension_format(images[0])
512
+
513
+ # transformations (affine transformation + rescaling + normalization)
514
+ if self.do_affine_transform:
515
+ new_images = []
516
+ for image, image_boxes in zip(images, boxes):
517
+ for box in image_boxes:
518
+ center, scale = box_to_center_and_scale(
519
+ box,
520
+ image_width=size["width"],
521
+ image_height=size["height"],
522
+ normalize_factor=self.normalize_factor,
523
+ )
524
+ transformed_image = self.affine_transform(
525
+ image, center, scale, rotation=0, size=size, input_data_format=input_data_format
526
+ )
527
+ new_images.append(transformed_image)
528
+ images = new_images
529
+
530
+ # For batch processing, the number of boxes must be consistent across all images in the batch.
531
+ # When using a list input, the number of boxes can vary dynamically per image.
532
+ # The image processor creates pixel_values of shape (batch_size*num_persons, num_channels, height, width)
533
+
534
+ all_images = []
535
+ for image in images:
536
+ if do_rescale:
537
+ image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
538
+
539
+ if do_normalize:
540
+ image = self.normalize(
541
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
542
+ )
543
+
544
+ all_images.append(image)
545
+ images = [
546
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
547
+ for image in all_images
548
+ ]
549
+
550
+ data = {"pixel_values": images}
551
+ encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
552
+
553
+ return encoded_inputs
554
+
555
+ def keypoints_from_heatmaps(
556
+ self,
557
+ heatmaps: np.ndarray,
558
+ center: np.ndarray,
559
+ scale: np.ndarray,
560
+ kernel: int = 11,
561
+ ):
562
+ """
563
+ Get final keypoint predictions from heatmaps and transform them back to
564
+ the image.
565
+
566
+ Args:
567
+ heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width])`):
568
+ Model predicted heatmaps.
569
+ center (`np.ndarray` of shape `(batch_size, 2)`):
570
+ Center of the bounding box (x, y).
571
+ scale (`np.ndarray` of shape `(batch_size, 2)`):
572
+ Scale of the bounding box wrt original images of width and height.
573
+ kernel (int, *optional*, defaults to 11):
574
+ Gaussian kernel size (K) for modulation, which should match the heatmap gaussian sigma when training.
575
+ K=17 for sigma=3 and K=11 for sigma=2.
576
+
577
+ Returns:
578
+ tuple: A tuple containing keypoint predictions and scores.
579
+
580
+ - preds (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
581
+ Predicted keypoint location in images.
582
+ - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
583
+ Scores (confidence) of the keypoints.
584
+ """
585
+ batch_size, _, height, width = heatmaps.shape
586
+
587
+ coords, scores = get_keypoint_predictions(heatmaps)
588
+
589
+ preds = post_dark_unbiased_data_processing(coords, heatmaps, kernel=kernel)
590
+
591
+ # Transform back to the image
592
+ for i in range(batch_size):
593
+ preds[i] = transform_preds(preds[i], center=center[i], scale=scale[i], output_size=[height, width])
594
+
595
+ return preds, scores
596
+
597
+ def post_process_pose_estimation(
598
+ self,
599
+ outputs: "VitPoseEstimatorOutput",
600
+ boxes: Union[List[List[List[float]]], np.ndarray],
601
+ kernel_size: int = 11,
602
+ threshold: Optional[float] = None,
603
+ target_sizes: Union[TensorType, List[Tuple]] = None,
604
+ ):
605
+ """
606
+ Transform the heatmaps into keypoint predictions and transform them back to the image.
607
+
608
+ Args:
609
+ outputs (`VitPoseEstimatorOutput`):
610
+ VitPoseForPoseEstimation model outputs.
611
+ boxes (`List[List[List[float]]]` or `np.ndarray`):
612
+ List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding
613
+ box coordinates in COCO format (top_left_x, top_left_y, width, height).
614
+ kernel_size (`int`, *optional*, defaults to 11):
615
+ Gaussian kernel size (K) for modulation.
616
+ threshold (`float`, *optional*, defaults to None):
617
+ Score threshold to keep object detection predictions.
618
+ target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
619
+ Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
620
+ `(height, width)` of each image in the batch. If unset, predictions will be resized using the default size.
621
+ Returns:
622
+ `List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an image
623
+ in the batch as predicted by the model.
624
+ """
625
+
626
+ # First compute centers and scales for each bounding box
627
+ batch_size, num_keypoints, _, _ = outputs.heatmaps.shape
628
+
629
+ if target_sizes is not None:
630
+ if batch_size != len(target_sizes):
631
+ raise ValueError(
632
+ "Make sure that you pass in as many target sizes as the batch dimension of the logits"
633
+ )
634
+
635
+ centers = np.zeros((batch_size, 2), dtype=np.float32)
636
+ scales = np.zeros((batch_size, 2), dtype=np.float32)
637
+ flattened_boxes = list(itertools.chain(*boxes))
638
+ for i in range(batch_size):
639
+ if target_sizes is not None:
640
+ image_width, image_height = target_sizes[i][0], target_sizes[i][1]
641
+ scale_factor = np.array([image_width, image_height, image_width, image_height])
642
+ flattened_boxes[i] = flattened_boxes[i] * scale_factor
643
+ width, height = self.size["width"], self.size["height"]
644
+ center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height)
645
+ centers[i, :] = center
646
+ scales[i, :] = scale
647
+
648
+ preds, scores = self.keypoints_from_heatmaps(
649
+ outputs.heatmaps.cpu().numpy(), centers, scales, kernel=kernel_size
650
+ )
651
+
652
+ all_boxes = np.zeros((batch_size, 4), dtype=np.float32)
653
+ all_boxes[:, 0:2] = centers[:, 0:2]
654
+ all_boxes[:, 2:4] = scales[:, 0:2]
655
+
656
+ poses = torch.tensor(preds)
657
+ scores = torch.tensor(scores)
658
+ labels = torch.arange(0, num_keypoints)
659
+ bboxes_xyxy = torch.tensor(coco_to_pascal_voc(all_boxes))
660
+
661
+ results: List[List[Dict[str, torch.Tensor]]] = []
662
+
663
+ pose_bbox_pairs = zip(poses, scores, bboxes_xyxy)
664
+
665
+ for image_bboxes in boxes:
666
+ image_results: List[Dict[str, torch.Tensor]] = []
667
+ for _ in image_bboxes:
668
+ # Unpack the next pose and bbox_xyxy from the iterator
669
+ pose, score, bbox_xyxy = next(pose_bbox_pairs)
670
+ score = score.squeeze()
671
+ keypoints_labels = labels
672
+ if threshold is not None:
673
+ keep = score > threshold
674
+ pose = pose[keep]
675
+ score = score[keep]
676
+ keypoints_labels = keypoints_labels[keep]
677
+ pose_result = {"keypoints": pose, "scores": score, "labels": keypoints_labels, "bbox": bbox_xyxy}
678
+ image_results.append(pose_result)
679
+ results.append(image_results)
680
+
681
+ return results
682
+
683
+
684
+ __all__ = ["VitPoseImageProcessor"]
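For reference, a minimal end-to-end sketch of how `preprocess` and `post_process_pose_estimation` wrap around the model (assuming the published `usyd-community/vitpose-base-simple` checkpoint and network access):

```python
import requests
import torch
from PIL import Image
from transformers import VitPoseForPoseEstimation, VitPoseImageProcessor

processor = VitPoseImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")

url = "http://images.cocodataset.org/val2017/000000000139.jpg"
image = Image.open(requests.get(url, stream=True).raw)
boxes = [[[412.8, 157.61, 53.05, 138.01]]]  # one COCO-format box for the one image

inputs = processor(images=image, boxes=boxes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# one list per image, one dict per box
pose_results = processor.post_process_pose_estimation(outputs, boxes=boxes)
print(pose_results[0][0]["keypoints"].shape)  # torch.Size([17, 2])
```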
docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py ADDED
@@ -0,0 +1,340 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch VitPose model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+
24
+ from ...modeling_utils import PreTrainedModel
25
+ from ...utils import (
26
+ ModelOutput,
27
+ add_start_docstrings,
28
+ add_start_docstrings_to_model_forward,
29
+ logging,
30
+ replace_return_docstrings,
31
+ )
32
+ from ...utils.backbone_utils import load_backbone
33
+ from .configuration_vitpose import VitPoseConfig
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+ # General docstring
39
+ _CONFIG_FOR_DOC = "VitPoseConfig"
40
+
41
+
42
+ @dataclass
43
+ class VitPoseEstimatorOutput(ModelOutput):
44
+ """
45
+ Class for outputs of pose estimation models.
46
+
47
+ Args:
48
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
49
+ Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail.
50
+ heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`):
51
+ Heatmaps as predicted by the model.
52
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
53
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
54
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
55
+ (also called feature maps) of the model at the output of each stage.
56
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
57
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
58
+ sequence_length)`.
59
+
60
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
61
+ heads.
62
+ """
63
+
64
+ loss: Optional[torch.FloatTensor] = None
65
+ heatmaps: Optional[torch.FloatTensor] = None
66
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
67
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
68
+
69
+
70
+ class VitPosePreTrainedModel(PreTrainedModel):
71
+ """
72
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
73
+ models.
74
+ """
75
+
76
+ config_class = VitPoseConfig
77
+ base_model_prefix = "vit"
78
+ main_input_name = "pixel_values"
79
+ supports_gradient_checkpointing = True
80
+
81
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
82
+ """Initialize the weights"""
83
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
84
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
85
+ # `trunc_normal_cpu` not implemented in `half` issues
86
+ module.weight.data = nn.init.trunc_normal_(
87
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
88
+ ).to(module.weight.dtype)
89
+ if module.bias is not None:
90
+ module.bias.data.zero_()
91
+ elif isinstance(module, nn.LayerNorm):
92
+ module.bias.data.zero_()
93
+ module.weight.data.fill_(1.0)
94
+
95
+
96
+ VITPOSE_START_DOCSTRING = r"""
97
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
98
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
99
+ behavior.
100
+
101
+ Parameters:
102
+ config ([`VitPoseConfig`]): Model configuration class with all the parameters of the model.
103
+ Initializing with a config file does not load the weights associated with the model, only the
104
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
105
+ """
106
+
107
+ VITPOSE_INPUTS_DOCSTRING = r"""
108
+ Args:
109
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
110
+ Pixel values. Pixel values can be obtained using [`VitPoseImageProcessor`]. See
111
+ [`VitPoseImageProcessor.__call__`] for details.
112
+
113
+ dataset_index (`torch.Tensor` of shape `(batch_size,)`):
114
+ Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
115
+
116
+ This corresponds to the dataset index used during training. For a single dataset, index 0 refers to that dataset; for multiple datasets, index 0 refers to dataset A (e.g. MPII) and index 1 to dataset B (e.g. CrowdPose).
117
+
118
+ flip_pairs (`torch.tensor`, *optional*):
119
+ Whether to mirror pairs of keypoints (for example, left ear -- right ear).
120
+
121
+ output_attentions (`bool`, *optional*):
122
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
123
+ tensors for more detail.
124
+ output_hidden_states (`bool`, *optional*):
125
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
126
+ more detail.
127
+ return_dict (`bool`, *optional*):
128
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
129
+ """
130
+
131
+
132
+ def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
133
+ """Flip the flipped heatmaps back to the original form.
134
+
135
+ Args:
136
+ output_flipped (`torch.tensor` of shape `(batch_size, num_keypoints, height, width)`):
137
+ The output heatmaps obtained from the flipped images.
138
+ flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
139
+ Pairs of keypoints which are mirrored (for example, left ear -- right ear).
140
+ target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
141
+ Target type to use. Can be gaussian-heatmap or combined-target.
142
+ gaussian-heatmap: Classification target with gaussian distribution.
143
+ combined-target: The combination of classification target (response map) and regression target (offset map).
144
+ Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
145
+
146
+ Returns:
147
+ torch.Tensor: heatmaps that flipped back to the original image
148
+ """
149
+ if target_type not in ["gaussian-heatmap", "combined-target"]:
150
+ raise ValueError("target_type should be gaussian-heatmap or combined-target")
151
+
152
+ if output_flipped.ndim != 4:
153
+ raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
154
+ batch_size, num_keypoints, height, width = output_flipped.shape
155
+ channels = 1
156
+ if target_type == "combined-target":
157
+ channels = 3
158
+ output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
159
+ output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
160
+ output_flipped_back = output_flipped.clone()
161
+
162
+ # Swap left-right parts
163
+ for left, right in flip_pairs.tolist():
164
+ output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
165
+ output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
166
+ output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
167
+ # Flip horizontally
168
+ output_flipped_back = output_flipped_back.flip(-1)
169
+ return output_flipped_back
170
+
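A minimal sketch of the swap-then-mirror behaviour (values chosen for illustration):

```python
import torch

# Two keypoint channels that are left/right mirrors of each other.
heatmaps = torch.zeros(1, 2, 1, 4)
heatmaps[0, 0, 0, 0] = 1.0  # "left" keypoint fired at x=0 in the flipped image
restored = flip_back(heatmaps, flip_pairs=torch.tensor([[0, 1]]))
print(restored[0, 1, 0])  # tensor([0., 0., 0., 1.]) -> now "right", at x=3
```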
171
+
172
+ class VitPoseSimpleDecoder(nn.Module):
173
+ """
174
+ Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
175
+ feature maps into heatmaps.
176
+ """
177
+
178
+ def __init__(self, config) -> None:
179
+ super().__init__()
180
+
181
+ self.activation = nn.ReLU()
182
+ self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False)
183
+ self.conv = nn.Conv2d(
184
+ config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1
185
+ )
186
+
187
+ def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor:
188
+ # Transform input: ReLU + upsample
189
+ hidden_state = self.activation(hidden_state)
190
+ hidden_state = self.upsampling(hidden_state)
191
+ heatmaps = self.conv(hidden_state)
192
+
193
+ if flip_pairs is not None:
194
+ heatmaps = flip_back(heatmaps, flip_pairs)
195
+
196
+ return heatmaps
197
+
198
+
199
+ class VitPoseClassicDecoder(nn.Module):
200
+ """
201
+ Classic decoding head consisting of two deconvolutional blocks, followed by a 1x1 convolution layer,
202
+ turning the feature maps into heatmaps.
203
+ """
204
+
205
+ def __init__(self, config: VitPoseConfig):
206
+ super().__init__()
207
+
208
+ self.deconv1 = nn.ConvTranspose2d(
209
+ config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False
210
+ )
211
+ self.batchnorm1 = nn.BatchNorm2d(256)
212
+ self.relu1 = nn.ReLU()
213
+
214
+ self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
215
+ self.batchnorm2 = nn.BatchNorm2d(256)
216
+ self.relu2 = nn.ReLU()
217
+
218
+ self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0)
219
+
220
+ def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None):
221
+ hidden_state = self.deconv1(hidden_state)
222
+ hidden_state = self.batchnorm1(hidden_state)
223
+ hidden_state = self.relu1(hidden_state)
224
+
225
+ hidden_state = self.deconv2(hidden_state)
226
+ hidden_state = self.batchnorm2(hidden_state)
227
+ hidden_state = self.relu2(hidden_state)
228
+
229
+ heatmaps = self.conv(hidden_state)
230
+
231
+ if flip_pairs is not None:
232
+ heatmaps = flip_back(heatmaps, flip_pairs)
233
+
234
+ return heatmaps
235
+
236
+
237
+ @add_start_docstrings(
238
+ "The VitPose model with a pose estimation head on top.",
239
+ VITPOSE_START_DOCSTRING,
240
+ )
241
+ class VitPoseForPoseEstimation(VitPosePreTrainedModel):
242
+ def __init__(self, config: VitPoseConfig) -> None:
243
+ super().__init__(config)
244
+
245
+ self.backbone = load_backbone(config)
246
+
247
+ # add backbone attributes
248
+ if not hasattr(self.backbone.config, "hidden_size"):
249
+ raise ValueError("The backbone should have a hidden_size attribute")
250
+ if not hasattr(self.backbone.config, "image_size"):
251
+ raise ValueError("The backbone should have an image_size attribute")
252
+ if not hasattr(self.backbone.config, "patch_size"):
253
+ raise ValueError("The backbone should have a patch_size attribute")
254
+
255
+ self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config)
256
+
257
+ # Initialize weights and apply final processing
258
+ self.post_init()
259
+
260
+ @add_start_docstrings_to_model_forward(VITPOSE_INPUTS_DOCSTRING)
261
+ @replace_return_docstrings(output_type=VitPoseEstimatorOutput, config_class=_CONFIG_FOR_DOC)
262
+ def forward(
263
+ self,
264
+ pixel_values: torch.Tensor,
265
+ dataset_index: Optional[torch.Tensor] = None,
266
+ flip_pairs: Optional[torch.Tensor] = None,
267
+ labels: Optional[torch.Tensor] = None,
268
+ output_attentions: Optional[bool] = None,
269
+ output_hidden_states: Optional[bool] = None,
270
+ return_dict: Optional[bool] = None,
271
+ ) -> Union[tuple, VitPoseEstimatorOutput]:
272
+ """
273
+ Returns:
274
+
275
+ Examples:
276
+
277
+ ```python
278
+ >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation
279
+ >>> import torch
280
+ >>> from PIL import Image
281
+ >>> import requests
282
+
283
+ >>> processor = AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple")
284
+ >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple")
285
+
286
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
287
+ >>> image = Image.open(requests.get(url, stream=True).raw)
288
+ >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
289
+ >>> inputs = processor(image, boxes=boxes, return_tensors="pt")
290
+
291
+ >>> with torch.no_grad():
292
+ ... outputs = model(**inputs)
293
+ >>> heatmaps = outputs.heatmaps
294
+ ```"""
295
+
296
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
297
+ output_hidden_states = (
298
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
299
+ )
300
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
301
+
302
+ loss = None
303
+ if labels is not None:
304
+ raise NotImplementedError("Training is not yet supported")
305
+
306
+ outputs = self.backbone.forward_with_filtered_kwargs(
307
+ pixel_values,
308
+ dataset_index=dataset_index,
309
+ output_hidden_states=output_hidden_states,
310
+ output_attentions=output_attentions,
311
+ return_dict=return_dict,
312
+ )
313
+
314
+ # Turn the output hidden states into a tensor of shape (batch_size, num_channels, height, width)
315
+ sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1]
316
+ batch_size = sequence_output.shape[0]
317
+ patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0]
318
+ patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1]
319
+ sequence_output = (
320
+ sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous()
321
+ )
322
+
323
+ heatmaps = self.head(sequence_output, flip_pairs=flip_pairs)
324
+
325
+ if not return_dict:
326
+ if output_hidden_states:
327
+ output = (heatmaps,) + outputs[1:]
328
+ else:
329
+ output = (heatmaps,) + outputs[2:]
330
+ return ((loss,) + output) if loss is not None else output
331
+
332
+ return VitPoseEstimatorOutput(
333
+ loss=loss,
334
+ heatmaps=heatmaps,
335
+ hidden_states=outputs.hidden_states,
336
+ attentions=outputs.attentions,
337
+ )
338
+
339
+
340
+ __all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"]
docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # flake8: noqa
2
+ # There's no way to ignore "F401 '...' imported but unused" warnings in this
3
+ # module, but to preserve other warnings. So, don't check this module at all.
4
+ from typing import TYPE_CHECKING
5
+
6
+ from ...utils import _LazyModule
7
+ from ...utils.import_utils import define_import_structure
8
+
9
+
10
+ if TYPE_CHECKING:
11
+ from .configuration_vitpose_backbone import *
12
+ from .modeling_vitpose_backbone import *
13
+ else:
14
+ import sys
15
+
16
+ _file = globals()["__file__"]
17
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py ADDED
@@ -0,0 +1,139 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """VitPose backbone configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+ from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
28
+ VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of the VitPose
30
+ [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ image_size (`List[int]`, *optional*, defaults to `[256, 192]`):
37
+ The size (resolution) of each image.
38
+ patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
39
+ The size (resolution) of each patch.
40
+ num_channels (`int`, *optional*, defaults to 3):
41
+ The number of input channels.
42
+ hidden_size (`int`, *optional*, defaults to 768):
43
+ Dimensionality of the encoder layers and the pooler layer.
44
+ num_hidden_layers (`int`, *optional*, defaults to 12):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 12):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ mlp_ratio (`int`, *optional*, defaults to 4):
49
+ The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
50
+ num_experts (`int`, *optional*, defaults to 1):
51
+ The number of experts in the MoE layer.
52
+ part_features (`int`, *optional*, defaults to 256):
53
+ The number of part features to output. Only used in case `num_experts` is greater than 1.
54
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
55
+ The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
56
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
57
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
58
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
59
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
60
+ The dropout ratio for the attention probabilities.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
64
+ The epsilon used by the layer normalization layers.
65
+ qkv_bias (`bool`, *optional*, defaults to `True`):
66
+ Whether to add a bias to the queries, keys and values.
67
+ out_features (`List[str]`, *optional*):
68
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
69
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
70
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
71
+ same order as defined in the `stage_names` attribute.
72
+ out_indices (`List[int]`, *optional*):
73
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
74
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
75
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
76
+ same order as defined in the `stage_names` attribute.
77
+
78
+ Example:
79
+
80
+ ```python
81
+ >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
82
+
83
+ >>> # Initializing a VitPose configuration
84
+ >>> configuration = VitPoseBackboneConfig()
85
+
86
+ >>> # Initializing a model (with random weights) from the configuration
87
+ >>> model = VitPoseBackbone(configuration)
88
+
89
+ >>> # Accessing the model configuration
90
+ >>> configuration = model.config
91
+ ```"""
92
+
93
+ model_type = "vitpose_backbone"
94
+
95
+ def __init__(
96
+ self,
97
+ image_size=[256, 192],
98
+ patch_size=[16, 16],
99
+ num_channels=3,
100
+ hidden_size=768,
101
+ num_hidden_layers=12,
102
+ num_attention_heads=12,
103
+ mlp_ratio=4,
104
+ num_experts=1,
105
+ part_features=256,
106
+ hidden_act="gelu",
107
+ hidden_dropout_prob=0.0,
108
+ attention_probs_dropout_prob=0.0,
109
+ initializer_range=0.02,
110
+ layer_norm_eps=1e-12,
111
+ qkv_bias=True,
112
+ out_features=None,
113
+ out_indices=None,
114
+ **kwargs,
115
+ ):
116
+ super().__init__(**kwargs)
117
+
118
+ self.hidden_size = hidden_size
119
+ self.num_hidden_layers = num_hidden_layers
120
+ self.num_attention_heads = num_attention_heads
121
+ self.mlp_ratio = mlp_ratio
122
+ self.num_experts = num_experts
123
+ self.part_features = part_features
124
+ self.hidden_act = hidden_act
125
+ self.hidden_dropout_prob = hidden_dropout_prob
126
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
127
+ self.initializer_range = initializer_range
128
+ self.layer_norm_eps = layer_norm_eps
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.qkv_bias = qkv_bias
133
+ self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
134
+ self._out_features, self._out_indices = get_aligned_output_features_output_indices(
135
+ out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
136
+ )
137
+
138
+
139
+ __all__ = ["VitPoseBackboneConfig"]
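A small illustration of the `stage_names` / `out_indices` alignment set up in `__init__` above (a sketch, not part of the file):

```python
from transformers import VitPoseBackboneConfig

# stage_names is ["stem", "stage1", ..., "stage12"] for the default 12 hidden
# layers; out_features is derived from out_indices when only the latter is given.
config = VitPoseBackboneConfig(out_indices=[12])
print(config.stage_names[:3])  # ['stem', 'stage1', 'stage2']
print(config.out_features)     # ['stage12']
```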
docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py ADDED
@@ -0,0 +1,579 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch VitPose backbone model.
16
+
17
+ This code is the same as the original Vision Transformer (ViT) with 2 modifications:
18
+ - use of padding=2 in the patch embedding layer
19
+ - addition of a mixture-of-experts MLP layer
20
+ """
21
+
22
+ import collections.abc
23
+ from typing import Callable, Optional, Set, Tuple, Union
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+
29
+ from ...activations import ACT2FN
30
+ from ...modeling_outputs import BackboneOutput, BaseModelOutput
31
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
33
+ from ...utils import (
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from ...utils.backbone_utils import BackboneMixin
40
+ from .configuration_vitpose_backbone import VitPoseBackboneConfig
41
+
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ # General docstring
46
+ _CONFIG_FOR_DOC = "VitPoseBackboneConfig"
47
+
48
+
49
+ class VitPoseBackbonePatchEmbeddings(nn.Module):
50
+ """Image to Patch Embedding."""
51
+
52
+ def __init__(self, config):
53
+ super().__init__()
54
+
55
+ image_size = config.image_size
56
+ patch_size = config.patch_size
57
+ num_channels = config.num_channels
58
+ embed_dim = config.hidden_size
59
+
60
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
61
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
62
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
63
+ self.image_size = image_size
64
+ self.patch_size = patch_size
65
+ self.num_patches = num_patches
66
+
67
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2)
68
+
69
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
70
+ height, width = pixel_values.shape[-2:]
71
+ if height != self.image_size[0] or width != self.image_size[1]:
72
+ raise ValueError(
73
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
74
+ )
75
+ embeddings = self.projection(pixel_values)
76
+
77
+ embeddings = embeddings.flatten(2).transpose(1, 2)
78
+ return embeddings
79
+
80
+
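Despite the non-standard `padding=2`, the projection above still produces the `(image_size // patch_size)` grid that `num_patches` assumes, because the kernel and stride are both 16. A quick sanity sketch of the shapes:

```python
# Sanity check of the padded patch projection (shapes only).
import torch
from torch import nn

proj = nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), padding=2)
out = proj(torch.randn(1, 3, 256, 192))
print(out.shape)  # torch.Size([1, 768, 16, 12]) -> 16 * 12 = 192 patches
```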
81
+ class VitPoseBackboneEmbeddings(nn.Module):
82
+ """
83
+ Construct the position and patch embeddings.
84
+ """
85
+
86
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
87
+ super().__init__()
88
+
89
+ self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config)
90
+ num_patches = self.patch_embeddings.num_patches
91
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
92
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
93
+
94
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
95
+ embeddings = self.patch_embeddings(pixel_values)
96
+
97
+ # add positional encoding to each token
98
+ embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1]
99
+
100
+ embeddings = self.dropout(embeddings)
101
+
102
+ return embeddings
103
+
104
+
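Note that the position table above keeps `num_patches + 1` slots (the CLS slot inherited from ViT) even though no CLS token is appended; the `[:, :1]` slice is instead broadcast-added onto every patch. A standalone sketch of that broadcast, with illustrative shapes:

```python
# Sketch of the position-embedding broadcast in the forward above.
import torch

batch, num_patches, hidden = 2, 192, 768
patches = torch.randn(batch, num_patches, hidden)
pos = torch.zeros(1, num_patches + 1, hidden)  # slot 0 is the unused CLS position

out = patches + pos[:, 1:] + pos[:, :1]  # (1, 1, hidden) broadcasts over all patches
assert out.shape == (batch, num_patches, hidden)
```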
105
+ # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
106
+ def eager_attention_forward(
107
+ module: nn.Module,
108
+ query: torch.Tensor,
109
+ key: torch.Tensor,
110
+ value: torch.Tensor,
111
+ attention_mask: Optional[torch.Tensor],
112
+ scaling: float,
113
+ dropout: float = 0.0,
114
+ **kwargs,
115
+ ):
116
+ # Take the dot product between "query" and "key" to get the raw attention scores.
117
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
118
+
119
+ # Normalize the attention scores to probabilities.
120
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
121
+
122
+ # This is actually dropping out entire tokens to attend to, which might
123
+ # seem a bit unusual, but is taken from the original Transformer paper.
124
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
125
+
126
+ # Mask heads if we want to
127
+ if attention_mask is not None:
128
+ attn_weights = attn_weights * attention_mask
129
+
130
+ attn_output = torch.matmul(attn_weights, value)
131
+ attn_output = attn_output.transpose(1, 2).contiguous()
132
+
133
+ return attn_output, attn_weights
134
+
135
+
136
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone
137
+ class VitPoseBackboneSelfAttention(nn.Module):
138
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
139
+ super().__init__()
140
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
141
+ raise ValueError(
142
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
143
+ f"heads {config.num_attention_heads}."
144
+ )
145
+
146
+ self.config = config
147
+ self.num_attention_heads = config.num_attention_heads
148
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
149
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
150
+ self.dropout_prob = config.attention_probs_dropout_prob
151
+ self.scaling = self.attention_head_size**-0.5
152
+ self.is_causal = False
153
+
154
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
155
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
156
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
157
+
158
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
159
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
160
+ x = x.view(new_x_shape)
161
+ return x.permute(0, 2, 1, 3)
162
+
163
+ def forward(
164
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
165
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
166
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
167
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
168
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
169
+
170
+ attention_interface: Callable = eager_attention_forward
171
+ if self.config._attn_implementation != "eager":
172
+ if self.config._attn_implementation == "sdpa" and output_attentions:
173
+ logger.warning_once(
174
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
175
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
176
+ )
177
+ else:
178
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
179
+
180
+ context_layer, attention_probs = attention_interface(
181
+ self,
182
+ query_layer,
183
+ key_layer,
184
+ value_layer,
185
+ head_mask,
186
+ is_causal=self.is_causal,
187
+ scaling=self.scaling,
188
+ dropout=0.0 if not self.training else self.dropout_prob,
189
+ )
190
+
191
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
192
+ context_layer = context_layer.reshape(new_context_layer_shape)
193
+
194
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
195
+
196
+ return outputs
197
+
198
+
199
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone
200
+ class VitPoseBackboneSelfOutput(nn.Module):
201
+ """
202
+ The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the
203
+ layernorm applied before each block.
204
+ """
205
+
206
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
207
+ super().__init__()
208
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
209
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
210
+
211
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
212
+ hidden_states = self.dense(hidden_states)
213
+ hidden_states = self.dropout(hidden_states)
214
+
215
+ return hidden_states
216
+
217
+
218
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone
219
+ class VitPoseBackboneAttention(nn.Module):
220
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
221
+ super().__init__()
222
+ self.attention = VitPoseBackboneSelfAttention(config)
223
+ self.output = VitPoseBackboneSelfOutput(config)
224
+ self.pruned_heads = set()
225
+
226
+ def prune_heads(self, heads: Set[int]) -> None:
227
+ if len(heads) == 0:
228
+ return
229
+ heads, index = find_pruneable_heads_and_indices(
230
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
231
+ )
232
+
233
+ # Prune linear layers
234
+ self.attention.query = prune_linear_layer(self.attention.query, index)
235
+ self.attention.key = prune_linear_layer(self.attention.key, index)
236
+ self.attention.value = prune_linear_layer(self.attention.value, index)
237
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
238
+
239
+ # Update hyper params and store pruned heads
240
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
241
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
242
+ self.pruned_heads = self.pruned_heads.union(heads)
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states: torch.Tensor,
247
+ head_mask: Optional[torch.Tensor] = None,
248
+ output_attentions: bool = False,
249
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
250
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
251
+
252
+ attention_output = self.output(self_outputs[0], hidden_states)
253
+
254
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
255
+ return outputs
256
+
257
+
258
+ class VitPoseBackboneMoeMLP(nn.Module):
259
+ def __init__(self, config: VitPoseBackboneConfig):
260
+ super().__init__()
261
+
262
+ in_features = out_features = config.hidden_size
263
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
264
+
265
+ num_experts = config.num_experts
266
+ part_features = config.part_features
267
+
268
+ self.part_features = part_features
269
+ self.fc1 = nn.Linear(in_features, hidden_features)
270
+ self.act = ACT2FN[config.hidden_act]
271
+ self.fc2 = nn.Linear(hidden_features, out_features - part_features)
272
+ self.drop = nn.Dropout(config.hidden_dropout_prob)
273
+
274
+ self.num_experts = num_experts
275
+ experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
276
+ self.experts = nn.ModuleList(experts)
277
+
278
+ def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
279
+ expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])
280
+
281
+ hidden_state = self.fc1(hidden_state)
282
+ hidden_state = self.act(hidden_state)
283
+ shared_hidden_state = self.fc2(hidden_state)
284
+ indices = indices.view(-1, 1, 1)
285
+
286
+ # to support ddp training
287
+ for i in range(self.num_experts):
288
+ selected_index = indices == i
289
+ current_hidden_state = self.experts[i](hidden_state) * selected_index
290
+ expert_hidden_state = expert_hidden_state + current_hidden_state
291
+
292
+ hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)
293
+
294
+ return hidden_state
295
+
296
+
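The MoE forward above runs every expert on the full batch and masks out the samples not routed to it, so all expert parameters receive gradients under DDP. A toy sketch of that dense masking, with hypothetical sizes:

```python
# Toy illustration of the dense expert masking used above (not the module itself).
import torch

batch, seq, hidden_features, part_features, num_experts = 4, 6, 8, 5, 2
hidden = torch.randn(batch, seq, hidden_features)
experts = [torch.nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
indices = torch.tensor([0, 1, 1, 0]).view(-1, 1, 1)  # per-sample dataset index

out = torch.zeros(batch, seq, part_features)
for i, expert in enumerate(experts):
    out = out + expert(hidden) * (indices == i)  # mask keeps only samples routed to i
```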
297
+ class VitPoseBackboneMLP(nn.Module):
298
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
299
+ super().__init__()
300
+ in_features = out_features = config.hidden_size
301
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
302
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
303
+ self.activation = ACT2FN[config.hidden_act]
304
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
305
+
306
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
307
+ hidden_state = self.fc1(hidden_state)
308
+ hidden_state = self.activation(hidden_state)
309
+ hidden_state = self.fc2(hidden_state)
310
+ return hidden_state
311
+
312
+
313
+ class VitPoseBackboneLayer(nn.Module):
314
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
315
+ super().__init__()
316
+ self.num_experts = config.num_experts
317
+ self.attention = VitPoseBackboneAttention(config)
318
+ self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config)
319
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
320
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: torch.Tensor,
325
+ dataset_index: Optional[torch.Tensor] = None,
326
+ head_mask: Optional[torch.Tensor] = None,
327
+ output_attentions: bool = False,
328
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
329
+ # Validate dataset_index when using multiple experts
330
+ if self.num_experts > 1 and dataset_index is None:
331
+ raise ValueError(
332
+ "dataset_index must be provided when using multiple experts "
333
+ f"(num_experts={self.num_experts}). Please provide dataset_index "
334
+ "to the forward pass."
335
+ )
336
+ self_attention_outputs = self.attention(
337
+ self.layernorm_before(hidden_states), # in VitPoseBackbone, layernorm is applied before self-attention
338
+ head_mask,
339
+ output_attentions=output_attentions,
340
+ )
341
+ attention_output = self_attention_outputs[0]
342
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
343
+
344
+ # first residual connection
345
+ hidden_states = attention_output + hidden_states
346
+
347
+ layer_output = self.layernorm_after(hidden_states)
348
+ if self.num_experts == 1:
349
+ layer_output = self.mlp(layer_output)
350
+ else:
351
+ layer_output = self.mlp(layer_output, indices=dataset_index)
352
+
353
+ # second residual connection
354
+ layer_output = layer_output + hidden_states
355
+
356
+ outputs = (layer_output,) + outputs
357
+
358
+ return outputs
359
+
360
+
361
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone
362
+ class VitPoseBackboneEncoder(nn.Module):
363
+ def __init__(self, config: VitPoseBackboneConfig) -> None:
364
+ super().__init__()
365
+ self.config = config
366
+ self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)])
367
+ self.gradient_checkpointing = False
368
+
369
+ # Ignore copy
370
+ def forward(
371
+ self,
372
+ hidden_states: torch.Tensor,
373
+ dataset_index: Optional[torch.Tensor] = None,
374
+ head_mask: Optional[torch.Tensor] = None,
375
+ output_attentions: bool = False,
376
+ output_hidden_states: bool = False,
377
+ return_dict: bool = True,
378
+ ) -> Union[tuple, BaseModelOutput]:
379
+ all_hidden_states = () if output_hidden_states else None
380
+ all_self_attentions = () if output_attentions else None
381
+
382
+ for i, layer_module in enumerate(self.layer):
383
+ if output_hidden_states:
384
+ all_hidden_states = all_hidden_states + (hidden_states,)
385
+
386
+ layer_head_mask = head_mask[i] if head_mask is not None else None
387
+
388
+ if self.gradient_checkpointing and self.training:
389
+ layer_outputs = self._gradient_checkpointing_func(
390
+ layer_module.__call__,
391
+ hidden_states,
392
+ dataset_index,
393
+ layer_head_mask,
394
+ output_attentions,
395
+ )
396
+ else:
397
+ layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions)
398
+
399
+ hidden_states = layer_outputs[0]
400
+
401
+ if output_attentions:
402
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
403
+
404
+ if output_hidden_states:
405
+ all_hidden_states = all_hidden_states + (hidden_states,)
406
+
407
+ if not return_dict:
408
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
409
+ return BaseModelOutput(
410
+ last_hidden_state=hidden_states,
411
+ hidden_states=all_hidden_states,
412
+ attentions=all_self_attentions,
413
+ )
414
+
415
+
416
+ class VitPoseBackbonePreTrainedModel(PreTrainedModel):
417
+ """
418
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
419
+ models.
420
+ """
421
+
422
+ config_class = VitPoseBackboneConfig
423
+ base_model_prefix = "vit"
424
+ main_input_name = "pixel_values"
425
+ supports_gradient_checkpointing = True
426
+ _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"]
427
+ _supports_sdpa = True
428
+ _supports_flash_attn_2 = True
429
+
430
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None:
431
+ """Initialize the weights"""
432
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
433
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
434
+ # `trunc_normal_cpu` not implemented in `half` issues
435
+ module.weight.data = nn.init.trunc_normal_(
436
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
437
+ ).to(module.weight.dtype)
438
+ if module.bias is not None:
439
+ module.bias.data.zero_()
440
+ elif isinstance(module, nn.LayerNorm):
441
+ module.bias.data.zero_()
442
+ module.weight.data.fill_(1.0)
443
+ elif isinstance(module, VitPoseBackboneEmbeddings):
444
+ module.position_embeddings.data = nn.init.trunc_normal_(
445
+ module.position_embeddings.data.to(torch.float32),
446
+ mean=0.0,
447
+ std=self.config.initializer_range,
448
+ ).to(module.position_embeddings.dtype)
449
+
450
+
451
+ VITPOSE_BACKBONE_START_DOCSTRING = r"""
452
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
453
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
454
+ behavior.
455
+
456
+ Parameters:
457
+ config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model.
458
+ Initializing with a config file does not load the weights associated with the model, only the
459
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
460
+ """
461
+
462
+ VITPOSE_BACKBONE_INPUTS_DOCSTRING = r"""
463
+ Args:
464
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
465
+ Pixel values.
466
+
467
+ dataset_index (`torch.Tensor` of shape `(batch_size,)`):
468
+ Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
469
+
470
+ This corresponds to the dataset index used during training, e.g. index 0 refers to COCO.
471
+
472
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
473
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
474
+
475
+ - 1 indicates the head is **not masked**,
476
+ - 0 indicates the head is **masked**.
477
+
478
+ output_attentions (`bool`, *optional*):
479
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
480
+ tensors for more detail.
481
+ output_hidden_states (`bool`, *optional*):
482
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
483
+ more detail.
484
+ return_dict (`bool`, *optional*):
485
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
486
+ """
487
+
488
+
489
+ @add_start_docstrings(
490
+ "The VitPose backbone useful for downstream tasks.",
491
+ VITPOSE_BACKBONE_START_DOCSTRING,
492
+ )
493
+ class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin):
494
+ def __init__(self, config: VitPoseBackboneConfig):
495
+ super().__init__(config)
496
+ super()._init_backbone(config)
497
+
498
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
499
+ self.embeddings = VitPoseBackboneEmbeddings(config)
500
+ self.encoder = VitPoseBackboneEncoder(config)
501
+
502
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
503
+
504
+ # Initialize weights and apply final processing
505
+ self.post_init()
506
+
507
+ @add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING)
508
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
509
+ def forward(
510
+ self,
511
+ pixel_values: torch.Tensor,
512
+ dataset_index: Optional[torch.Tensor] = None,
513
+ head_mask: Optional[torch.Tensor] = None,
514
+ output_attentions: Optional[bool] = None,
515
+ output_hidden_states: Optional[bool] = None,
516
+ return_dict: Optional[bool] = None,
517
+ ):
518
+ """
519
+ Returns:
520
+
521
+ Examples:
522
+
523
+ ```python
524
+ >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
525
+ >>> import torch
526
+
527
+ >>> config = VitPoseBackboneConfig(out_indices=[-1])
528
+ >>> model = VitPoseBackbone(config)
529
+
530
+ >>> pixel_values = torch.randn(1, 3, 256, 192)
531
+ >>> dataset_index = torch.tensor([1])
532
+ >>> outputs = model(pixel_values, dataset_index)
533
+ ```"""
534
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
535
+ output_hidden_states = (
536
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
537
+ )
538
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
539
+
540
+ # Prepare head mask if needed
541
+ # 1.0 in head_mask indicate we keep the head
542
+ # attention_probs has shape bsz x n_heads x N x N
543
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
544
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
545
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
546
+
547
+ embedding_output = self.embeddings(pixel_values)
548
+
549
+ outputs = self.encoder(
550
+ embedding_output,
551
+ dataset_index=dataset_index,
552
+ head_mask=head_mask,
553
+ output_attentions=output_attentions,
554
+ output_hidden_states=True,
555
+ return_dict=return_dict,
556
+ )
557
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
558
+
559
+ feature_maps = ()
560
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
561
+ if stage in self.out_features:
562
+ hidden_state = self.layernorm(hidden_state)
563
+ feature_maps += (hidden_state,)
564
+
565
+ if not return_dict:
566
+ if output_hidden_states:
567
+ output = (feature_maps,) + outputs[1:]
568
+ else:
569
+ output = (feature_maps,) + outputs[2:]
570
+ return output
571
+
572
+ return BackboneOutput(
573
+ feature_maps=feature_maps,
574
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
575
+ attentions=outputs.attentions,
576
+ )
577
+
578
+
579
+ __all__ = ["VitPoseBackbonePreTrainedModel", "VitPoseBackbone"]
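For reference, the `(batch, seq_len, hidden_size)` feature map returned by this backbone is turned back into a spatial grid before the pose head (see the `permute(...).reshape(...)` in `VitPoseForPoseEstimation.forward` earlier in this upload). A standalone sketch with the default 256x192 input and 16x16 patches:

```python
import torch

batch, hidden = 1, 768
patch_height, patch_width = 256 // 16, 192 // 16  # 16 x 12 grid
tokens = torch.randn(batch, patch_height * patch_width, hidden)

# (B, seq, hidden) -> (B, hidden, seq) -> (B, hidden, patch_height, patch_width)
grid = tokens.permute(0, 2, 1).reshape(batch, -1, patch_height, patch_width).contiguous()
assert grid.shape == (batch, hidden, patch_height, patch_width)
```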
docs/transformers/build/lib/transformers/models/vits/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_vits import *
22
+ from .modeling_vits import *
23
+ from .tokenization_vits import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vits/configuration_vits.py ADDED
@@ -0,0 +1,253 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """VITS model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class VitsConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`VitsModel`]. It is used to instantiate a VITS
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the VITS
29
+ [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ vocab_size (`int`, *optional*, defaults to 38):
36
+ Vocabulary size of the VITS model. Defines the number of different tokens that can be represented by the
37
+ `input_ids` passed to the forward method of [`VitsModel`].
38
+ hidden_size (`int`, *optional*, defaults to 192):
39
+ Dimensionality of the text encoder layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 6):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 2):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ window_size (`int`, *optional*, defaults to 4):
45
+ Window size for the relative positional embeddings in the attention layers of the Transformer encoder.
46
+ use_bias (`bool`, *optional*, defaults to `True`):
47
+ Whether to use bias in the key, query, value projection layers in the Transformer encoder.
48
+ ffn_dim (`int`, *optional*, defaults to 768):
49
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
50
+ layerdrop (`float`, *optional*, defaults to 0.1):
51
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
52
+ for more details.
53
+ ffn_kernel_size (`int`, *optional*, defaults to 3):
54
+ Kernel size of the 1D convolution layers used by the feed-forward network in the Transformer encoder.
55
+ flow_size (`int`, *optional*, defaults to 192):
56
+ Dimensionality of the flow layers.
57
+ spectrogram_bins (`int`, *optional*, defaults to 513):
58
+ Number of frequency bins in the target spectrogram.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
60
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
61
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
62
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
63
+ The dropout probability for all fully connected layers in the embeddings and encoder.
64
+ attention_dropout (`float`, *optional*, defaults to 0.1):
65
+ The dropout ratio for the attention probabilities.
66
+ activation_dropout (`float`, *optional*, defaults to 0.1):
67
+ The dropout ratio for activations inside the fully connected layer.
68
+ initializer_range (`float`, *optional*, defaults to 0.02):
69
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
70
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
71
+ The epsilon used by the layer normalization layers.
72
+ use_stochastic_duration_prediction (`bool`, *optional*, defaults to `True`):
73
+ Whether to use the stochastic duration prediction module or the regular duration predictor.
74
+ num_speakers (`int`, *optional*, defaults to 1):
75
+ Number of speakers if this is a multi-speaker model.
76
+ speaker_embedding_size (`int`, *optional*, defaults to 0):
77
+ Number of channels used by the speaker embeddings. Is zero for single-speaker models.
78
+ upsample_initial_channel (`int`, *optional*, defaults to 512):
79
+ The number of input channels into the HiFi-GAN upsampling network.
80
+ upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 2, 2]`):
81
+ A tuple of integers defining the stride of each 1D convolutional layer in the HiFi-GAN upsampling network.
82
+ The length of `upsample_rates` defines the number of convolutional layers and has to match the length of
83
+ `upsample_kernel_sizes`.
84
+ upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[16, 16, 4, 4]`):
85
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the HiFi-GAN upsampling
86
+ network. The length of `upsample_kernel_sizes` defines the number of convolutional layers and has to match
87
+ the length of `upsample_rates`.
88
+ resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`):
89
+ A tuple of integers defining the kernel sizes of the 1D convolutional layers in the HiFi-GAN
90
+ multi-receptive field fusion (MRF) module.
91
+ resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
92
+ A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
93
+ HiFi-GAN multi-receptive field fusion (MRF) module.
94
+ leaky_relu_slope (`float`, *optional*, defaults to 0.1):
95
+ The angle of the negative slope used by the leaky ReLU activation.
96
+ depth_separable_channels (`int`, *optional*, defaults to 2):
97
+ Number of channels to use in each depth-separable block.
98
+ depth_separable_num_layers (`int`, *optional*, defaults to 3):
99
+ Number of convolutional layers to use in each depth-separable block.
100
+ duration_predictor_flow_bins (`int`, *optional*, defaults to 10):
101
+ Number of channels to map using the unconstrained rational spline in the duration predictor model.
102
+ duration_predictor_tail_bound (`float`, *optional*, defaults to 5.0):
103
+ Value of the tail bin boundary when computing the unconstrained rational spline in the duration predictor
104
+ model.
105
+ duration_predictor_kernel_size (`int`, *optional*, defaults to 3):
106
+ Kernel size of the 1D convolution layers used in the duration predictor model.
107
+ duration_predictor_dropout (`float`, *optional*, defaults to 0.5):
108
+ The dropout ratio for the duration predictor model.
109
+ duration_predictor_num_flows (`int`, *optional*, defaults to 4):
110
+ Number of flow stages used by the duration predictor model.
111
+ duration_predictor_filter_channels (`int`, *optional*, defaults to 256):
112
+ Number of channels for the convolution layers used in the duration predictor model.
113
+ prior_encoder_num_flows (`int`, *optional*, defaults to 4):
114
+ Number of flow stages used by the prior encoder flow model.
115
+ prior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 4):
116
+ Number of WaveNet layers used by the prior encoder flow model.
117
+ posterior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 16):
118
+ Number of WaveNet layers used by the posterior encoder model.
119
+ wavenet_kernel_size (`int`, *optional*, defaults to 5):
120
+ Kernel size of the 1D convolution layers used in the WaveNet model.
121
+ wavenet_dilation_rate (`int`, *optional*, defaults to 1):
122
+ Dilation rates of the dilated 1D convolutional layers used in the WaveNet model.
123
+ wavenet_dropout (`float`, *optional*, defaults to 0.0):
124
+ The dropout ratio for the WaveNet layers.
125
+ speaking_rate (`float`, *optional*, defaults to 1.0):
126
+ Speaking rate. Larger values give faster synthesised speech.
127
+ noise_scale (`float`, *optional*, defaults to 0.667):
128
+ How random the speech prediction is. Larger values create more variation in the predicted speech.
129
+ noise_scale_duration (`float`, *optional*, defaults to 0.8):
130
+ How random the duration prediction is. Larger values create more variation in the predicted durations.
131
+ sampling_rate (`int`, *optional*, defaults to 16000):
132
+ The sampling rate at which the output audio waveform is digitized, expressed in hertz (Hz).
133
+
134
+ Example:
135
+
136
+ ```python
137
+ >>> from transformers import VitsModel, VitsConfig
138
+
139
+ >>> # Initializing a "facebook/mms-tts-eng" style configuration
140
+ >>> configuration = VitsConfig()
141
+
142
+ >>> # Initializing a model (with random weights) from the "facebook/mms-tts-eng" style configuration
143
+ >>> model = VitsModel(configuration)
144
+
145
+ >>> # Accessing the model configuration
146
+ >>> configuration = model.config
147
+ ```"""
148
+
149
+ model_type = "vits"
150
+
151
+ def __init__(
152
+ self,
153
+ vocab_size=38,
154
+ hidden_size=192,
155
+ num_hidden_layers=6,
156
+ num_attention_heads=2,
157
+ window_size=4,
158
+ use_bias=True,
159
+ ffn_dim=768,
160
+ layerdrop=0.1,
161
+ ffn_kernel_size=3,
162
+ flow_size=192,
163
+ spectrogram_bins=513,
164
+ hidden_act="relu",
165
+ hidden_dropout=0.1,
166
+ attention_dropout=0.1,
167
+ activation_dropout=0.1,
168
+ initializer_range=0.02,
169
+ layer_norm_eps=1e-5,
170
+ use_stochastic_duration_prediction=True,
171
+ num_speakers=1,
172
+ speaker_embedding_size=0,
173
+ upsample_initial_channel=512,
174
+ upsample_rates=[8, 8, 2, 2],
175
+ upsample_kernel_sizes=[16, 16, 4, 4],
176
+ resblock_kernel_sizes=[3, 7, 11],
177
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
178
+ leaky_relu_slope=0.1,
179
+ depth_separable_channels=2,
180
+ depth_separable_num_layers=3,
181
+ duration_predictor_flow_bins=10,
182
+ duration_predictor_tail_bound=5.0,
183
+ duration_predictor_kernel_size=3,
184
+ duration_predictor_dropout=0.5,
185
+ duration_predictor_num_flows=4,
186
+ duration_predictor_filter_channels=256,
187
+ prior_encoder_num_flows=4,
188
+ prior_encoder_num_wavenet_layers=4,
189
+ posterior_encoder_num_wavenet_layers=16,
190
+ wavenet_kernel_size=5,
191
+ wavenet_dilation_rate=1,
192
+ wavenet_dropout=0.0,
193
+ speaking_rate=1.0,
194
+ noise_scale=0.667,
195
+ noise_scale_duration=0.8,
196
+ sampling_rate=16_000,
197
+ **kwargs,
198
+ ):
199
+ self.vocab_size = vocab_size
200
+ self.hidden_size = hidden_size
201
+ self.num_hidden_layers = num_hidden_layers
202
+ self.num_attention_heads = num_attention_heads
203
+ self.window_size = window_size
204
+ self.use_bias = use_bias
205
+ self.ffn_dim = ffn_dim
206
+ self.layerdrop = layerdrop
207
+ self.ffn_kernel_size = ffn_kernel_size
208
+ self.flow_size = flow_size
209
+ self.spectrogram_bins = spectrogram_bins
210
+ self.hidden_act = hidden_act
211
+ self.hidden_dropout = hidden_dropout
212
+ self.attention_dropout = attention_dropout
213
+ self.activation_dropout = activation_dropout
214
+ self.initializer_range = initializer_range
215
+ self.layer_norm_eps = layer_norm_eps
216
+ self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
217
+ self.num_speakers = num_speakers
218
+ self.speaker_embedding_size = speaker_embedding_size
219
+ self.upsample_initial_channel = upsample_initial_channel
220
+ self.upsample_rates = upsample_rates
221
+ self.upsample_kernel_sizes = upsample_kernel_sizes
222
+ self.resblock_kernel_sizes = resblock_kernel_sizes
223
+ self.resblock_dilation_sizes = resblock_dilation_sizes
224
+ self.leaky_relu_slope = leaky_relu_slope
225
+ self.depth_separable_channels = depth_separable_channels
226
+ self.depth_separable_num_layers = depth_separable_num_layers
227
+ self.duration_predictor_flow_bins = duration_predictor_flow_bins
228
+ self.duration_predictor_tail_bound = duration_predictor_tail_bound
229
+ self.duration_predictor_kernel_size = duration_predictor_kernel_size
230
+ self.duration_predictor_dropout = duration_predictor_dropout
231
+ self.duration_predictor_num_flows = duration_predictor_num_flows
232
+ self.duration_predictor_filter_channels = duration_predictor_filter_channels
233
+ self.prior_encoder_num_flows = prior_encoder_num_flows
234
+ self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
235
+ self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
236
+ self.wavenet_kernel_size = wavenet_kernel_size
237
+ self.wavenet_dilation_rate = wavenet_dilation_rate
238
+ self.wavenet_dropout = wavenet_dropout
239
+ self.speaking_rate = speaking_rate
240
+ self.noise_scale = noise_scale
241
+ self.noise_scale_duration = noise_scale_duration
242
+ self.sampling_rate = sampling_rate
243
+
244
+ if len(upsample_kernel_sizes) != len(upsample_rates):
245
+ raise ValueError(
246
+ f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
247
+ f"`upsample_rates` ({len(upsample_rates)})"
248
+ )
249
+
250
+ super().__init__(**kwargs)
251
+
252
+
253
+ __all__ = ["VitsConfig"]
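A short end-to-end sketch for the configuration above, assuming the `facebook/mms-tts-eng` checkpoint named in the docstring is available:

```python
# TTS inference sketch (requires downloading the facebook/mms-tts-eng checkpoint).
import torch
from transformers import VitsModel, VitsTokenizer

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model = VitsModel.from_pretrained("facebook/mms-tts-eng")

inputs = tokenizer(text="Hello, world!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# (batch, num_samples), sampled at config.sampling_rate; the HiFi-GAN decoder
# upsamples each latent frame by prod(upsample_rates) = 8 * 8 * 2 * 2 = 256 samples.
waveform = outputs.waveform
```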
docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py ADDED
@@ -0,0 +1,390 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert VITS checkpoint."""
16
+
17
+ import argparse
18
+ import json
19
+ import tempfile
20
+
21
+ import torch
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ from transformers import VitsConfig, VitsModel, VitsTokenizer, logging
25
+
26
+
27
+ logging.set_verbosity_info()
28
+ logger = logging.get_logger("transformers.models.vits")
29
+
30
+ MAPPING_TEXT_ENCODER = {
31
+ "enc_p.emb": "text_encoder.embed_tokens",
32
+ "enc_p.encoder.attn_layers.*.conv_k": "text_encoder.encoder.layers.*.attention.k_proj",
33
+ "enc_p.encoder.attn_layers.*.conv_v": "text_encoder.encoder.layers.*.attention.v_proj",
34
+ "enc_p.encoder.attn_layers.*.conv_q": "text_encoder.encoder.layers.*.attention.q_proj",
35
+ "enc_p.encoder.attn_layers.*.conv_o": "text_encoder.encoder.layers.*.attention.out_proj",
36
+ "enc_p.encoder.attn_layers.*.emb_rel_k": "text_encoder.encoder.layers.*.attention.emb_rel_k",
37
+ "enc_p.encoder.attn_layers.*.emb_rel_v": "text_encoder.encoder.layers.*.attention.emb_rel_v",
38
+ "enc_p.encoder.norm_layers_1.*.gamma": "text_encoder.encoder.layers.*.layer_norm.weight",
39
+ "enc_p.encoder.norm_layers_1.*.beta": "text_encoder.encoder.layers.*.layer_norm.bias",
40
+ "enc_p.encoder.ffn_layers.*.conv_1": "text_encoder.encoder.layers.*.feed_forward.conv_1",
41
+ "enc_p.encoder.ffn_layers.*.conv_2": "text_encoder.encoder.layers.*.feed_forward.conv_2",
42
+ "enc_p.encoder.norm_layers_2.*.gamma": "text_encoder.encoder.layers.*.final_layer_norm.weight",
43
+ "enc_p.encoder.norm_layers_2.*.beta": "text_encoder.encoder.layers.*.final_layer_norm.bias",
44
+ "enc_p.proj": "text_encoder.project",
45
+ }
46
+ MAPPING_STOCHASTIC_DURATION_PREDICTOR = {
47
+ "dp.pre": "duration_predictor.conv_pre",
48
+ "dp.proj": "duration_predictor.conv_proj",
49
+ "dp.convs.convs_sep.*": "duration_predictor.conv_dds.convs_dilated.*",
50
+ "dp.convs.convs_1x1.*": "duration_predictor.conv_dds.convs_pointwise.*",
51
+ "dp.convs.norms_1.*.gamma": "duration_predictor.conv_dds.norms_1.*.weight",
52
+ "dp.convs.norms_1.*.beta": "duration_predictor.conv_dds.norms_1.*.bias",
53
+ "dp.convs.norms_2.*.gamma": "duration_predictor.conv_dds.norms_2.*.weight",
54
+ "dp.convs.norms_2.*.beta": "duration_predictor.conv_dds.norms_2.*.bias",
55
+ "dp.flows.0.logs": "duration_predictor.flows.0.log_scale",
56
+ "dp.flows.0.m": "duration_predictor.flows.0.translate",
57
+ "dp.flows.*.pre": "duration_predictor.flows.*.conv_pre",
58
+ "dp.flows.*.proj": "duration_predictor.flows.*.conv_proj",
59
+ "dp.flows.*.convs.convs_1x1.0": "duration_predictor.flows.*.conv_dds.convs_pointwise.0",
60
+ "dp.flows.*.convs.convs_1x1.1": "duration_predictor.flows.*.conv_dds.convs_pointwise.1",
61
+ "dp.flows.*.convs.convs_1x1.2": "duration_predictor.flows.*.conv_dds.convs_pointwise.2",
62
+ "dp.flows.*.convs.convs_sep.0": "duration_predictor.flows.*.conv_dds.convs_dilated.0",
63
+ "dp.flows.*.convs.convs_sep.1": "duration_predictor.flows.*.conv_dds.convs_dilated.1",
64
+ "dp.flows.*.convs.convs_sep.2": "duration_predictor.flows.*.conv_dds.convs_dilated.2",
65
+ "dp.flows.*.convs.norms_1.0.gamma": "duration_predictor.flows.*.conv_dds.norms_1.0.weight",
66
+ "dp.flows.*.convs.norms_1.0.beta": "duration_predictor.flows.*.conv_dds.norms_1.0.bias",
67
+ "dp.flows.*.convs.norms_1.1.gamma": "duration_predictor.flows.*.conv_dds.norms_1.1.weight",
68
+ "dp.flows.*.convs.norms_1.1.beta": "duration_predictor.flows.*.conv_dds.norms_1.1.bias",
69
+ "dp.flows.*.convs.norms_1.2.gamma": "duration_predictor.flows.*.conv_dds.norms_1.2.weight",
70
+ "dp.flows.*.convs.norms_1.2.beta": "duration_predictor.flows.*.conv_dds.norms_1.2.bias",
71
+ "dp.flows.*.convs.norms_2.0.gamma": "duration_predictor.flows.*.conv_dds.norms_2.0.weight",
72
+ "dp.flows.*.convs.norms_2.0.beta": "duration_predictor.flows.*.conv_dds.norms_2.0.bias",
73
+ "dp.flows.*.convs.norms_2.1.gamma": "duration_predictor.flows.*.conv_dds.norms_2.1.weight",
74
+ "dp.flows.*.convs.norms_2.1.beta": "duration_predictor.flows.*.conv_dds.norms_2.1.bias",
75
+ "dp.flows.*.convs.norms_2.2.gamma": "duration_predictor.flows.*.conv_dds.norms_2.2.weight",
76
+ "dp.flows.*.convs.norms_2.2.beta": "duration_predictor.flows.*.conv_dds.norms_2.2.bias",
77
+ "dp.post_pre": "duration_predictor.post_conv_pre",
78
+ "dp.post_proj": "duration_predictor.post_conv_proj",
79
+ "dp.post_convs.convs_sep.*": "duration_predictor.post_conv_dds.convs_dilated.*",
80
+ "dp.post_convs.convs_1x1.*": "duration_predictor.post_conv_dds.convs_pointwise.*",
81
+ "dp.post_convs.norms_1.*.gamma": "duration_predictor.post_conv_dds.norms_1.*.weight",
82
+ "dp.post_convs.norms_1.*.beta": "duration_predictor.post_conv_dds.norms_1.*.bias",
83
+ "dp.post_convs.norms_2.*.gamma": "duration_predictor.post_conv_dds.norms_2.*.weight",
84
+ "dp.post_convs.norms_2.*.beta": "duration_predictor.post_conv_dds.norms_2.*.bias",
85
+ "dp.post_flows.0.logs": "duration_predictor.post_flows.0.log_scale",
86
+ "dp.post_flows.0.m": "duration_predictor.post_flows.0.translate",
87
+ "dp.post_flows.*.pre": "duration_predictor.post_flows.*.conv_pre",
88
+ "dp.post_flows.*.proj": "duration_predictor.post_flows.*.conv_proj",
89
+ "dp.post_flows.*.convs.convs_1x1.0": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.0",
90
+ "dp.post_flows.*.convs.convs_1x1.1": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.1",
91
+ "dp.post_flows.*.convs.convs_1x1.2": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.2",
92
+ "dp.post_flows.*.convs.convs_sep.0": "duration_predictor.post_flows.*.conv_dds.convs_dilated.0",
93
+ "dp.post_flows.*.convs.convs_sep.1": "duration_predictor.post_flows.*.conv_dds.convs_dilated.1",
94
+ "dp.post_flows.*.convs.convs_sep.2": "duration_predictor.post_flows.*.conv_dds.convs_dilated.2",
95
+ "dp.post_flows.*.convs.norms_1.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.0.weight",
96
+ "dp.post_flows.*.convs.norms_1.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.0.bias",
97
+ "dp.post_flows.*.convs.norms_1.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.1.weight",
98
+ "dp.post_flows.*.convs.norms_1.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.1.bias",
99
+ "dp.post_flows.*.convs.norms_1.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.2.weight",
100
+ "dp.post_flows.*.convs.norms_1.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.2.bias",
101
+ "dp.post_flows.*.convs.norms_2.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.0.weight",
102
+ "dp.post_flows.*.convs.norms_2.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.0.bias",
103
+ "dp.post_flows.*.convs.norms_2.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.1.weight",
104
+ "dp.post_flows.*.convs.norms_2.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.1.bias",
105
+ "dp.post_flows.*.convs.norms_2.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.2.weight",
106
+ "dp.post_flows.*.convs.norms_2.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.2.bias",
107
+ "dp.cond": "duration_predictor.cond", # num_speakers > 1
108
+ }
109
+ MAPPING_FLOW = {
110
+ "flow.flows.*.pre": "flow.flows.*.conv_pre",
111
+ "flow.flows.*.enc.in_layers.0": "flow.flows.*.wavenet.in_layers.0",
112
+ "flow.flows.*.enc.in_layers.1": "flow.flows.*.wavenet.in_layers.1",
113
+ "flow.flows.*.enc.in_layers.2": "flow.flows.*.wavenet.in_layers.2",
114
+ "flow.flows.*.enc.in_layers.3": "flow.flows.*.wavenet.in_layers.3",
115
+ "flow.flows.*.enc.res_skip_layers.0": "flow.flows.*.wavenet.res_skip_layers.0",
116
+ "flow.flows.*.enc.res_skip_layers.1": "flow.flows.*.wavenet.res_skip_layers.1",
117
+ "flow.flows.*.enc.res_skip_layers.2": "flow.flows.*.wavenet.res_skip_layers.2",
118
+ "flow.flows.*.enc.res_skip_layers.3": "flow.flows.*.wavenet.res_skip_layers.3",
119
+ "flow.flows.*.enc.cond_layer": "flow.flows.*.wavenet.cond_layer", # num_speakers > 1
120
+ "flow.flows.*.post": "flow.flows.*.conv_post",
121
+ }
122
+ MAPPING_GENERATOR = {
123
+ "dec.conv_pre": "decoder.conv_pre",
124
+ "dec.ups.0": "decoder.upsampler.0",
125
+ "dec.ups.1": "decoder.upsampler.1",
126
+ "dec.ups.2": "decoder.upsampler.2",
127
+ "dec.ups.3": "decoder.upsampler.3",
128
+ "dec.resblocks.*.convs1.0": "decoder.resblocks.*.convs1.0",
129
+ "dec.resblocks.*.convs1.1": "decoder.resblocks.*.convs1.1",
130
+ "dec.resblocks.*.convs1.2": "decoder.resblocks.*.convs1.2",
131
+ "dec.resblocks.*.convs2.0": "decoder.resblocks.*.convs2.0",
132
+ "dec.resblocks.*.convs2.1": "decoder.resblocks.*.convs2.1",
133
+ "dec.resblocks.*.convs2.2": "decoder.resblocks.*.convs2.2",
134
+ "dec.conv_post": "decoder.conv_post",
135
+ "dec.cond": "decoder.cond", # num_speakers > 1
136
+ }
137
+ MAPPING_POSTERIOR_ENCODER = {
138
+ "enc_q.pre": "posterior_encoder.conv_pre",
139
+ "enc_q.enc.in_layers.*": "posterior_encoder.wavenet.in_layers.*",
140
+ "enc_q.enc.res_skip_layers.*": "posterior_encoder.wavenet.res_skip_layers.*",
141
+ "enc_q.enc.cond_layer": "posterior_encoder.wavenet.cond_layer", # num_speakers > 1
142
+ "enc_q.proj": "posterior_encoder.conv_proj",
143
+ }
144
+ MAPPING = {
145
+ **MAPPING_TEXT_ENCODER,
146
+ **MAPPING_STOCHASTIC_DURATION_PREDICTOR,
147
+ **MAPPING_FLOW,
148
+ **MAPPING_GENERATOR,
149
+ **MAPPING_POSTERIOR_ENCODER,
150
+ "emb_g": "embed_speaker", # num_speakers > 1
151
+ }
152
+ TOP_LEVEL_KEYS = []
153
+ IGNORE_KEYS = []
154
+
155
+
156
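Editor's note: a minimal sketch of how a wildcard MAPPING entry resolves to a concrete HF parameter name; the index arithmetic mirrors the Flip-layer remapping performed in `recursively_load_weights` below, and the key shown is illustrative only.

# Original checkpoints interleave coupling and Flip layers, so original index 4
# becomes HF index 4 // 2 = 2 once the Flip layers are dropped.
orig_key = "flow.flows.4.pre"
mapped = "flow.flows.*.conv_pre".replace("*", str(4 // 2))
assert mapped == "flow.flows.2.conv_pre"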
+ def set_recursively(hf_pointer, key, value, full_name, weight_type):
157
+ for attribute in key.split("."):
158
+ hf_pointer = getattr(hf_pointer, attribute)
159
+
160
+ if weight_type is not None:
161
+ hf_shape = getattr(hf_pointer, weight_type).shape
162
+ else:
163
+ hf_shape = hf_pointer.shape
164
+
165
+ # strip off the kernel dimension at the end (original weights are Conv1d)
166
+ if key.endswith(".k_proj") or key.endswith(".v_proj") or key.endswith(".q_proj") or key.endswith(".out_proj"):
167
+ value = value.squeeze(-1)
168
+
169
+ if hf_shape != value.shape:
170
+ raise ValueError(
171
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
172
+ f" {value.shape} for {full_name}"
173
+ )
174
+
175
+ if weight_type == "weight":
176
+ hf_pointer.weight.data = value
177
+ elif weight_type == "weight_g":
178
+ hf_pointer.weight_g.data = value
179
+ elif weight_type == "weight_v":
180
+ hf_pointer.weight_v.data = value
181
+ elif weight_type == "bias":
182
+ hf_pointer.bias.data = value
183
+ elif weight_type == "running_mean":
184
+ hf_pointer.running_mean.data = value
185
+ elif weight_type == "running_var":
186
+ hf_pointer.running_var.data = value
187
+ elif weight_type == "num_batches_tracked":
188
+ hf_pointer.num_batches_tracked.data = value
189
+ else:
190
+ hf_pointer.data = value
191
+
192
+ logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.")
193
+
194
+
195
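Editor's note: `set_recursively` walks dotted attribute paths with `getattr`; a self-contained sketch of the same mechanism (the module names here are made up for the example):

import torch
from torch import nn

model = nn.Module()
model.decoder = nn.Module()
model.decoder.conv_pre = nn.Conv1d(2, 2, 1)

pointer = model
for attribute in "decoder.conv_pre".split("."):
    pointer = getattr(pointer, attribute)
pointer.bias.data = torch.zeros(2)  # same effect as weight_type == "bias"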
+ def should_ignore(name, ignore_keys):
196
+ for key in ignore_keys:
197
+ if key.endswith(".*"):
198
+ if name.startswith(key[:-1]):
199
+ return True
200
+ elif ".*." in key:
201
+ prefix, suffix = key.split(".*.")
202
+ if prefix in name and suffix in name:
203
+ return True
204
+ elif key in name:
205
+ return True
206
+ return False
207
+
208
+
209
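Editor's note: `should_ignore` accepts three key styles (plain substring, trailing `.*`, infix `.*.`); a quick self-contained check with hypothetical keys (IGNORE_KEYS is empty in this script):

assert should_ignore("enc_p.emb.weight", ["enc_p.emb"])          # plain substring
assert should_ignore("dec.ups.3.bias", ["dec.ups.*"])            # trailing wildcard
assert should_ignore("flow.flows.1.pre.weight", ["flow.*.pre"])  # infix wildcard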
+ def recursively_load_weights(fairseq_dict, hf_model):
210
+ unused_weights = []
211
+
212
+ for name, value in fairseq_dict.items():
213
+ if should_ignore(name, IGNORE_KEYS):
214
+ logger.info(f"{name} was ignored")
215
+ continue
216
+
217
+ is_used = False
218
+ for key, mapped_key in MAPPING.items():
219
+ if key.endswith(".*"):
220
+ key = key[:-1]
221
+ elif "*" in key:
222
+ prefix, suffix = key.split(".*.")
223
+ if prefix in name and suffix in name:
224
+ key = suffix
225
+
226
+ if key in name:
227
+ is_used = True
228
+ if mapped_key.endswith(".*"):
229
+ layer_index = name.split(key)[-1].split(".")[0]
230
+ mapped_key = mapped_key.replace("*", layer_index)
231
+ elif "*" in mapped_key:
232
+ layer_index = name.split(key)[0].split(".")[-2]
233
+
234
+ # remap the layer index since we removed the Flip layers
235
+ if "flow.flows" in mapped_key:
236
+ layer_index = str(int(layer_index) // 2)
237
+ if "duration_predictor.flows" in mapped_key or "duration_predictor.post_flows" in mapped_key:
238
+ layer_index = str(int(layer_index) // 2 + 1)
239
+
240
+ mapped_key = mapped_key.replace("*", layer_index)
241
+ if "weight_g" in name:
242
+ weight_type = "weight_g"
243
+ elif "weight_v" in name:
244
+ weight_type = "weight_v"
245
+ elif "bias" in name:
246
+ weight_type = "bias"
247
+ elif "weight" in name:
248
+ weight_type = "weight"
249
+ elif "running_mean" in name:
250
+ weight_type = "running_mean"
251
+ elif "running_var" in name:
252
+ weight_type = "running_var"
253
+ elif "num_batches_tracked" in name:
254
+ weight_type = "num_batches_tracked"
255
+ else:
256
+ weight_type = None
257
+ set_recursively(hf_model, mapped_key, value, name, weight_type)
258
+ continue
259
+ if not is_used:
260
+ unused_weights.append(name)
261
+
262
+ logger.warning(f"Unused weights: {unused_weights}")
263
+
264
+
265
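Editor's note: the `// 2` and `// 2 + 1` remaps above account for the Flip layers that exist in the original VITS graph but are folded into the HF forward pass; a numeric check, assuming the original interleaving of coupling and Flip layers:

# flow.flows: couplings sit at even original indices (Flips at the odd ones)
assert [i // 2 for i in (0, 2, 4)] == [0, 1, 2]
# duration predictor flows: slot 0 is an ElementwiseAffine and the conv flows
# start at original index 1, hence the extra "+ 1"
assert [i // 2 + 1 for i in (1, 3, 5)] == [1, 2, 3]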
+ @torch.no_grad()
266
+ def convert_checkpoint(
267
+ pytorch_dump_folder_path,
268
+ checkpoint_path=None,
269
+ config_path=None,
270
+ vocab_path=None,
271
+ language=None,
272
+ num_speakers=None,
273
+ sampling_rate=None,
274
+ repo_id=None,
275
+ ):
276
+ """
277
+ Copy/paste/tweak model's weights to transformers design.
278
+ """
279
+ if config_path is not None:
280
+ config = VitsConfig.from_pretrained(config_path)
281
+ else:
282
+ config = VitsConfig()
283
+
284
+ if num_speakers:
285
+ config.num_speakers = num_speakers
286
+ config.speaker_embedding_size = 256
287
+
288
+ if sampling_rate:
289
+ config.sampling_rate = sampling_rate
290
+
291
+ if checkpoint_path is None:
292
+ logger.info(f"***Converting model: facebook/mms-tts {language}***")
293
+
294
+ vocab_path = hf_hub_download(
295
+ repo_id="facebook/mms-tts",
296
+ filename="vocab.txt",
297
+ subfolder=f"models/{language}",
298
+ )
299
+ config_file = hf_hub_download(
300
+ repo_id="facebook/mms-tts",
301
+ filename="config.json",
302
+ subfolder=f"models/{language}",
303
+ )
304
+ checkpoint_path = hf_hub_download(
305
+ repo_id="facebook/mms-tts",
306
+ filename="G_100000.pth",
307
+ subfolder=f"models/{language}",
308
+ )
309
+
310
+ with open(config_file, "r") as f:
311
+ data = f.read()
312
+ hps = json.loads(data)
313
+
314
+ is_uroman = hps["data"]["training_files"].split(".")[-1] == "uroman"
315
+ if is_uroman:
316
+ logger.warning("For this checkpoint, you should use `uroman` to convert input text before tokenizing it!")
317
+ else:
318
+ logger.info(f"***Converting model: {checkpoint_path}***")
319
+ is_uroman = False
320
+
321
+ # original VITS checkpoint
322
+ if vocab_path is None:
323
+ _pad = "_"
324
+ _punctuation = ';:,.!?¡¿—…"«»“” '
325
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
326
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
327
+ symbols = _pad + _punctuation + _letters + _letters_ipa
328
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
329
+ phonemize = True
330
+ else:
331
+ # Save vocab as temporary json file
332
+ symbols = [line.replace("\n", "") for line in open(vocab_path, encoding="utf-8").readlines()]
333
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
334
+ # MMS-TTS does not use a <pad> token, so we set it to the token used to space characters
335
+ _pad = symbols[0]
336
+ phonemize = False
337
+
338
+ with tempfile.NamedTemporaryFile() as tf:
339
+ with open(tf.name, "w", encoding="utf-8") as f:
340
+ f.write(json.dumps(symbol_to_id, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
341
+
342
+ tokenizer = VitsTokenizer(tf.name, language=language, phonemize=phonemize, is_uroman=is_uroman, pad_token=_pad)
343
+
344
+ config.vocab_size = len(symbols)
345
+ model = VitsModel(config)
346
+
347
+ model.decoder.apply_weight_norm()
348
+
349
+ orig_checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
350
+ recursively_load_weights(orig_checkpoint["model"], model)
351
+
352
+ model.decoder.remove_weight_norm()
353
+
354
+ model.save_pretrained(pytorch_dump_folder_path)
355
+ tokenizer.save_pretrained(pytorch_dump_folder_path)
356
+
357
+ if repo_id:
358
+ print("Pushing to the hub...")
359
+ tokenizer.push_to_hub(repo_id)
360
+ model.push_to_hub(repo_id)
361
+
362
+
363
+ if __name__ == "__main__":
364
+ parser = argparse.ArgumentParser()
365
+ parser.add_argument("--checkpoint_path", default=None, type=str, help="Local path to original checkpoint")
366
+ parser.add_argument("--vocab_path", default=None, type=str, help="Path to vocab.txt")
367
+ parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
368
+ parser.add_argument("--language", default=None, type=str, help="Tokenizer language (three-letter code)")
369
+ parser.add_argument("--num_speakers", default=None, type=int, help="Number of speakers")
370
+ parser.add_argument(
371
+ "--sampling_rate", default=None, type=int, help="Sampling rate on which the model was trained."
372
+ )
373
+ parser.add_argument(
374
+ "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model."
375
+ )
376
+ parser.add_argument(
377
+ "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub."
378
+ )
379
+
380
+ args = parser.parse_args()
381
+ convert_checkpoint(
382
+ args.pytorch_dump_folder_path,
383
+ args.checkpoint_path,
384
+ args.config_path,
385
+ args.vocab_path,
386
+ args.language,
387
+ args.num_speakers,
388
+ args.sampling_rate,
389
+ args.push_to_hub,
390
+ )
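Editor's note: a typical invocation for an MMS-TTS checkpoint pulled straight from the Hub might look as follows; the output folder and repo id are placeholders, and the script name assumes the file shown in this diff:

python convert_original_checkpoint.py \
    --language eng \
    --pytorch_dump_folder_path ./vits-mms-eng \
    --push_to_hub <your-namespace>/mms-tts-eng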
docs/transformers/build/lib/transformers/models/vits/modeling_vits.py ADDED
@@ -0,0 +1,1493 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch VITS model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from ...activations import ACT2FN
27
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
28
+ from ...integrations.fsdp import is_fsdp_managed_module
29
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
30
+ from ...modeling_outputs import (
31
+ BaseModelOutput,
32
+ ModelOutput,
33
+ )
34
+ from ...modeling_utils import PreTrainedModel
35
+ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
36
+ from .configuration_vits import VitsConfig
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ # General docstring
43
+ _CONFIG_FOR_DOC = "VitsConfig"
44
+
45
+
46
+ @dataclass
47
+ class VitsModelOutput(ModelOutput):
48
+ """
49
+ Describes the outputs for the VITS model, with potential hidden states and attentions.
50
+
51
+ Args:
52
+ waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
53
+ The final audio waveform predicted by the model.
54
+ sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
55
+ The length in samples of each element in the `waveform` batch.
56
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
57
+ The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi
58
+ GAN decoder model to obtain the final audio waveform.
59
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
60
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
61
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
62
+
63
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
64
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
65
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
66
+ sequence_length)`.
67
+
68
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
69
+ heads.
70
+ """
71
+
72
+ waveform: Optional[torch.FloatTensor] = None
73
+ sequence_lengths: Optional[torch.FloatTensor] = None
74
+ spectrogram: Optional[Tuple[torch.FloatTensor]] = None
75
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
76
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
77
+
78
+
79
+ @dataclass
80
+ class VitsTextEncoderOutput(ModelOutput):
81
+ """
82
+ Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
83
+
84
+ Args:
85
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86
+ Sequence of hidden-states at the output of the last layer of the model.
87
+ prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
88
+ The predicted mean values of the prior distribution for the latent text variables.
89
+ prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
90
+ The predicted log-variance values of the prior distribution for the latent text variables.
91
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
92
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
93
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
94
+
95
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
96
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
97
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
98
+ sequence_length)`.
99
+
100
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
101
+ heads.
102
+ """
103
+
104
+ last_hidden_state: Optional[torch.FloatTensor] = None
105
+ prior_means: Optional[torch.FloatTensor] = None
106
+ prior_log_variances: Optional[torch.FloatTensor] = None
107
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
108
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
109
+
110
+
111
+ @torch.jit.script
112
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :num_channels, :])
115
+ s_act = torch.sigmoid(in_act[:, num_channels:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
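Editor's note: the helper above is WaveNet's gated activation, tanh(a + b) * sigmoid(a + b) split along the channel axis; a quick shape check, run in this module's namespace and mirroring how `VitsWaveNet` calls it:

x = torch.randn(1, 8, 10)
n = torch.IntTensor([4])[0]  # first 4 channels -> tanh gate, last 4 -> sigmoid gate
y = fused_add_tanh_sigmoid_multiply(x, torch.zeros_like(x), n)
assert y.shape == (1, 4, 10)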
+ def _unconstrained_rational_quadratic_spline(
121
+ inputs,
122
+ unnormalized_widths,
123
+ unnormalized_heights,
124
+ unnormalized_derivatives,
125
+ reverse=False,
126
+ tail_bound=5.0,
127
+ min_bin_width=1e-3,
128
+ min_bin_height=1e-3,
129
+ min_derivative=1e-3,
130
+ ):
131
+ """
132
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
133
+ `tail_bound`, the transform behaves as an identity function.
134
+
135
+ Args:
136
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
137
+ Second half of the hidden-states input to the Vits convolutional flow module.
138
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
139
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
140
+ layer in the convolutional flow module
141
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
142
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
143
+ layer in the convolutional flow module
144
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
145
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
146
+ layer in the convolutional flow module
147
+ reverse (`bool`, *optional*, defaults to `False`):
148
+ Whether the model is being run in reverse mode.
149
+ tail_bound (`float`, *optional*, defaults to 5.0):
150
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
151
+ transform behaves as an identity function.
152
+ min_bin_width (`float`, *optional*, defaults to 1e-3):
153
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
154
+ min_bin_height (`float`, *optional*, defaults to 1e-3):
155
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
156
+ min_derivative (`float`, *optional*, defaults to 1e-3):
157
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
158
+ Returns:
159
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
160
+ Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
161
+ applied.
162
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
163
+ Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
164
+ limits applied.
165
+ """
166
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
167
+ outside_interval_mask = ~inside_interval_mask
168
+
169
+ outputs = torch.zeros_like(inputs)
170
+ log_abs_det = torch.zeros_like(inputs)
171
+ constant = np.log(np.exp(1 - min_derivative) - 1)
172
+
173
+ unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
174
+ unnormalized_derivatives[..., 0] = constant
175
+ unnormalized_derivatives[..., -1] = constant
176
+
177
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
178
+ log_abs_det[outside_interval_mask] = 0.0
179
+
180
+ outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
181
+ inputs=inputs[inside_interval_mask],
182
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
183
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
184
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
185
+ reverse=reverse,
186
+ tail_bound=tail_bound,
187
+ min_bin_width=min_bin_width,
188
+ min_bin_height=min_bin_height,
189
+ min_derivative=min_derivative,
190
+ )
191
+ return outputs, log_abs_det
192
+
193
+
194
+ def _rational_quadratic_spline(
195
+ inputs,
196
+ unnormalized_widths,
197
+ unnormalized_heights,
198
+ unnormalized_derivatives,
199
+ reverse,
200
+ tail_bound,
201
+ min_bin_width,
202
+ min_bin_height,
203
+ min_derivative,
204
+ ):
205
+ """
206
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
207
+ function `_unconstrained_rational_quadratic_spline`, this one expects all inputs to lie inside `[-tail_bound, tail_bound]` and raises an error otherwise.
208
+
209
+ Args:
210
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
211
+ Second half of the hidden-states input to the Vits convolutional flow module.
212
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
213
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
214
+ layer in the convolutional flow module
215
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
216
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
217
+ layer in the convolutional flow module
218
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
219
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
220
+ layer in the convolutional flow module
221
+ reverse (`bool`):
222
+ Whether the model is being run in reverse mode.
223
+ tail_bound (`float`):
224
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
225
+ transform behaves as an identity function.
226
+ min_bin_width (`float`):
227
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
228
+ min_bin_height (`float`):
229
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
230
+ min_derivative (`float`):
231
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
232
+ Returns:
233
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
234
+ Hidden-states as transformed by the piecewise rational quadratic function.
235
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
236
+ Logarithm of the absolute value of the determinants corresponding to the `outputs`.
237
+ """
238
+ upper_bound = tail_bound
239
+ lower_bound = -tail_bound
240
+
241
+ if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
242
+ raise ValueError("Input to a transform is not within its domain")
243
+
244
+ num_bins = unnormalized_widths.shape[-1]
245
+
246
+ if min_bin_width * num_bins > 1.0:
247
+ raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
248
+ if min_bin_height * num_bins > 1.0:
249
+ raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
250
+
251
+ widths = nn.functional.softmax(unnormalized_widths, dim=-1)
252
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
253
+ cumwidths = torch.cumsum(widths, dim=-1)
254
+ cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
255
+ cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
256
+ cumwidths[..., 0] = lower_bound
257
+ cumwidths[..., -1] = upper_bound
258
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
259
+
260
+ derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
261
+
262
+ heights = nn.functional.softmax(unnormalized_heights, dim=-1)
263
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
264
+ cumheights = torch.cumsum(heights, dim=-1)
265
+ cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
266
+ cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
267
+ cumheights[..., 0] = lower_bound
268
+ cumheights[..., -1] = upper_bound
269
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
270
+
271
+ bin_locations = cumheights if reverse else cumwidths
272
+ bin_locations[..., -1] += 1e-6
273
+ bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
274
+ bin_idx = bin_idx[..., None]
275
+
276
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
277
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
278
+
279
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
280
+ delta = heights / widths
281
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
282
+
283
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
284
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
285
+
286
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
287
+
288
+ intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
289
+ if not reverse:
290
+ theta = (inputs - input_cumwidths) / input_bin_widths
291
+ theta_one_minus_theta = theta * (1 - theta)
292
+
293
+ numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
294
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
295
+ outputs = input_cumheights + numerator / denominator
296
+
297
+ derivative_numerator = input_delta.pow(2) * (
298
+ input_derivatives_plus_one * theta.pow(2)
299
+ + 2 * input_delta * theta_one_minus_theta
300
+ + input_derivatives * (1 - theta).pow(2)
301
+ )
302
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
303
+ return outputs, log_abs_det
304
+ else:
305
+ # find the roots of a quadratic equation
306
+ intermediate2 = inputs - input_cumheights
307
+ intermediate3 = intermediate2 * intermediate1
308
+ a = input_heights * (input_delta - input_derivatives) + intermediate3
309
+ b = input_heights * input_derivatives - intermediate3
310
+ c = -input_delta * intermediate2
311
+
312
+ discriminant = b.pow(2) - 4 * a * c
313
+ if not (discriminant >= 0).all():
314
+ raise RuntimeError(f"invalid discriminant {discriminant}")
315
+
316
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
317
+ outputs = root * input_bin_widths + input_cumwidths
318
+
319
+ theta_one_minus_theta = root * (1 - root)
320
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
321
+ derivative_numerator = input_delta.pow(2) * (
322
+ input_derivatives_plus_one * root.pow(2)
323
+ + 2 * input_delta * theta_one_minus_theta
324
+ + input_derivatives * (1 - root).pow(2)
325
+ )
326
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
327
+ return outputs, -log_abs_det
328
+
329
+
330
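Editor's note: in the notation of Durkan et al. (Neural Spline Flows), with theta = (x - x_k) / Delta x_k and slope s_k = Delta y_k / Delta x_k, the forward branch above computes

g(\theta) = y_k + \frac{\Delta y_k \left[ s_k \theta^2 + d_k\, \theta (1 - \theta) \right]}{s_k + (d_{k+1} + d_k - 2 s_k)\, \theta (1 - \theta)},
\qquad
\log \lvert \det \rvert = \log \frac{s_k^2 \left[ d_{k+1} \theta^2 + 2 s_k\, \theta (1 - \theta) + d_k (1 - \theta)^2 \right]}{\left[ s_k + (d_{k+1} + d_k - 2 s_k)\, \theta (1 - \theta) \right]^2},

where `input_heights` is Delta y_k, `input_delta` is s_k, and `input_derivatives` / `input_derivatives_plus_one` are d_k and d_{k+1}.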
+ class VitsWaveNet(torch.nn.Module):
331
+ def __init__(self, config: VitsConfig, num_layers: int):
332
+ super().__init__()
333
+ self.hidden_size = config.hidden_size
334
+ self.num_layers = num_layers
335
+
336
+ self.in_layers = torch.nn.ModuleList()
337
+ self.res_skip_layers = torch.nn.ModuleList()
338
+ self.dropout = nn.Dropout(config.wavenet_dropout)
339
+
340
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
341
+ weight_norm = nn.utils.parametrizations.weight_norm
342
+ else:
343
+ weight_norm = nn.utils.weight_norm
344
+
345
+ if config.speaker_embedding_size != 0:
346
+ cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
347
+ self.cond_layer = weight_norm(cond_layer, name="weight")
348
+
349
+ for i in range(num_layers):
350
+ dilation = config.wavenet_dilation_rate**i
351
+ padding = (config.wavenet_kernel_size * dilation - dilation) // 2
352
+ in_layer = torch.nn.Conv1d(
353
+ in_channels=config.hidden_size,
354
+ out_channels=2 * config.hidden_size,
355
+ kernel_size=config.wavenet_kernel_size,
356
+ dilation=dilation,
357
+ padding=padding,
358
+ )
359
+ in_layer = weight_norm(in_layer, name="weight")
360
+ self.in_layers.append(in_layer)
361
+
362
+ # the last layer produces only skip channels; no residual half is needed
363
+ if i < num_layers - 1:
364
+ res_skip_channels = 2 * config.hidden_size
365
+ else:
366
+ res_skip_channels = config.hidden_size
367
+
368
+ res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
369
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
370
+ self.res_skip_layers.append(res_skip_layer)
371
+
372
+ def forward(self, inputs, padding_mask, global_conditioning=None):
373
+ outputs = torch.zeros_like(inputs)
374
+ num_channels_tensor = torch.IntTensor([self.hidden_size])
375
+
376
+ if global_conditioning is not None:
377
+ global_conditioning = self.cond_layer(global_conditioning)
378
+
379
+ for i in range(self.num_layers):
380
+ hidden_states = self.in_layers[i](inputs)
381
+
382
+ if global_conditioning is not None:
383
+ cond_offset = i * 2 * self.hidden_size
384
+ global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
385
+ else:
386
+ global_states = torch.zeros_like(hidden_states)
387
+
388
+ acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
389
+ acts = self.dropout(acts)
390
+
391
+ res_skip_acts = self.res_skip_layers[i](acts)
392
+ if i < self.num_layers - 1:
393
+ res_acts = res_skip_acts[:, : self.hidden_size, :]
394
+ inputs = (inputs + res_acts) * padding_mask
395
+ outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
396
+ else:
397
+ outputs = outputs + res_skip_acts
398
+
399
+ return outputs * padding_mask
400
+
401
+ def remove_weight_norm(self):
402
+ if hasattr(self, "cond_layer"):  # `speaker_embedding_size` is never stored on this module, so guard on the layer itself
403
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
404
+ for layer in self.in_layers:
405
+ torch.nn.utils.remove_weight_norm(layer)
406
+ for layer in self.res_skip_layers:
407
+ torch.nn.utils.remove_weight_norm(layer)
408
+
409
+
410
+ class VitsPosteriorEncoder(nn.Module):
411
+ def __init__(self, config: VitsConfig):
412
+ super().__init__()
413
+ self.out_channels = config.flow_size
414
+
415
+ self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
416
+ self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
417
+ self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)
418
+
419
+ def forward(self, inputs, padding_mask, global_conditioning=None):
420
+ inputs = self.conv_pre(inputs) * padding_mask
421
+ inputs = self.wavenet(inputs, padding_mask, global_conditioning)
422
+ stats = self.conv_proj(inputs) * padding_mask
423
+ mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
424
+ sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
425
+ return sampled, mean, log_stddev
426
+
427
+
428
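Editor's note: the sampling line in `VitsPosteriorEncoder.forward` is the standard reparameterization trick, z = mean + sigma * eps with eps ~ N(0, I), masked to the valid frames; a standalone sketch whose shapes assume the default `flow_size` of 192:

import torch

mean = torch.zeros(1, 192, 50)
log_stddev = torch.zeros(1, 192, 50)  # sigma = exp(0) = 1
padding_mask = torch.ones(1, 1, 50)
sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask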
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
429
+ class HifiGanResidualBlock(nn.Module):
430
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
431
+ super().__init__()
432
+ self.leaky_relu_slope = leaky_relu_slope
433
+
434
+ self.convs1 = nn.ModuleList(
435
+ [
436
+ nn.Conv1d(
437
+ channels,
438
+ channels,
439
+ kernel_size,
440
+ stride=1,
441
+ dilation=dilation[i],
442
+ padding=self.get_padding(kernel_size, dilation[i]),
443
+ )
444
+ for i in range(len(dilation))
445
+ ]
446
+ )
447
+ self.convs2 = nn.ModuleList(
448
+ [
449
+ nn.Conv1d(
450
+ channels,
451
+ channels,
452
+ kernel_size,
453
+ stride=1,
454
+ dilation=1,
455
+ padding=self.get_padding(kernel_size, 1),
456
+ )
457
+ for _ in range(len(dilation))
458
+ ]
459
+ )
460
+
461
+ def get_padding(self, kernel_size, dilation=1):
462
+ return (kernel_size * dilation - dilation) // 2
463
+
464
+ def apply_weight_norm(self):
465
+ weight_norm = nn.utils.weight_norm
466
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
467
+ weight_norm = nn.utils.parametrizations.weight_norm
468
+
469
+ for layer in self.convs1:
470
+ weight_norm(layer)
471
+ for layer in self.convs2:
472
+ weight_norm(layer)
473
+
474
+ def remove_weight_norm(self):
475
+ for layer in self.convs1:
476
+ nn.utils.remove_weight_norm(layer)
477
+ for layer in self.convs2:
478
+ nn.utils.remove_weight_norm(layer)
479
+
480
+ def forward(self, hidden_states):
481
+ for conv1, conv2 in zip(self.convs1, self.convs2):
482
+ residual = hidden_states
483
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
484
+ hidden_states = conv1(hidden_states)
485
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
486
+ hidden_states = conv2(hidden_states)
487
+ hidden_states = hidden_states + residual
488
+ return hidden_states
489
+
490
+
491
+ class VitsHifiGan(nn.Module):
492
+ def __init__(self, config: VitsConfig):
493
+ super().__init__()
494
+ self.config = config
495
+ self.num_kernels = len(config.resblock_kernel_sizes)
496
+ self.num_upsamples = len(config.upsample_rates)
497
+ self.conv_pre = nn.Conv1d(
498
+ config.flow_size,
499
+ config.upsample_initial_channel,
500
+ kernel_size=7,
501
+ stride=1,
502
+ padding=3,
503
+ )
504
+
505
+ self.upsampler = nn.ModuleList()
506
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
507
+ self.upsampler.append(
508
+ nn.ConvTranspose1d(
509
+ config.upsample_initial_channel // (2**i),
510
+ config.upsample_initial_channel // (2 ** (i + 1)),
511
+ kernel_size=kernel_size,
512
+ stride=upsample_rate,
513
+ padding=(kernel_size - upsample_rate) // 2,
514
+ )
515
+ )
516
+
517
+ self.resblocks = nn.ModuleList()
518
+ for i in range(len(self.upsampler)):
519
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
520
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
521
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
522
+
523
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
524
+
525
+ if config.speaker_embedding_size != 0:
526
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
527
+
528
+ def apply_weight_norm(self):
529
+ weight_norm = nn.utils.weight_norm
530
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
531
+ weight_norm = nn.utils.parametrizations.weight_norm
532
+
533
+ for layer in self.upsampler:
534
+ weight_norm(layer)
535
+ for layer in self.resblocks:
536
+ layer.apply_weight_norm()
537
+
538
+ def remove_weight_norm(self):
539
+ for layer in self.upsampler:
540
+ nn.utils.remove_weight_norm(layer)
541
+ for layer in self.resblocks:
542
+ layer.remove_weight_norm()
543
+
544
+ def forward(
545
+ self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
546
+ ) -> torch.FloatTensor:
547
+ r"""
548
+ Converts a spectrogram into a speech waveform.
549
+
550
+ Args:
551
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
552
+ Tensor containing the spectrograms.
553
+ global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
554
+ Tensor containing speaker embeddings, for multispeaker models.
555
+
556
+ Returns:
557
+ `torch.FloatTensor`: Tensor of shape `(batch_size, 1, num_frames)` containing the speech waveform.
558
+ """
559
+ hidden_states = self.conv_pre(spectrogram)
560
+
561
+ if global_conditioning is not None:
562
+ hidden_states = hidden_states + self.cond(global_conditioning)
563
+
564
+ for i in range(self.num_upsamples):
565
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
566
+ hidden_states = self.upsampler[i](hidden_states)
567
+
568
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
569
+ for j in range(1, self.num_kernels):
570
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
571
+ hidden_states = res_state / self.num_kernels
572
+
573
+ hidden_states = nn.functional.leaky_relu(hidden_states)
574
+ hidden_states = self.conv_post(hidden_states)
575
+ waveform = torch.tanh(hidden_states)
576
+ return waveform
577
+
578
+
579
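Editor's note: each ConvTranspose1d in `VitsHifiGan` multiplies the time axis by its stride, so the waveform length equals the spectrogram length times the product of `upsample_rates`; with the assumed config defaults [8, 8, 2, 2] that is a factor of 256:

import math

upsample_rates = [8, 8, 2, 2]  # assumed VitsConfig defaults
num_samples = 100 * math.prod(upsample_rates)  # 100 frames -> 25600 samples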
+ class VitsResidualCouplingLayer(nn.Module):
580
+ def __init__(self, config: VitsConfig):
581
+ super().__init__()
582
+ self.half_channels = config.flow_size // 2
583
+
584
+ self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
585
+ self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
586
+ self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
587
+
588
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
589
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
590
+ hidden_states = self.conv_pre(first_half) * padding_mask
591
+ hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
592
+ mean = self.conv_post(hidden_states) * padding_mask
593
+ log_stddev = torch.zeros_like(mean)
594
+
595
+ if not reverse:
596
+ second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
597
+ outputs = torch.cat([first_half, second_half], dim=1)
598
+ log_determinant = torch.sum(log_stddev, [1, 2])
599
+ return outputs, log_determinant
600
+ else:
601
+ second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
602
+ outputs = torch.cat([first_half, second_half], dim=1)
603
+ return outputs, None
604
+
605
+
606
+ class VitsResidualCouplingBlock(nn.Module):
607
+ def __init__(self, config: VitsConfig):
608
+ super().__init__()
609
+ self.flows = nn.ModuleList()
610
+ for _ in range(config.prior_encoder_num_flows):
611
+ self.flows.append(VitsResidualCouplingLayer(config))
612
+
613
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
614
+ if not reverse:
615
+ for flow in self.flows:
616
+ inputs, _ = flow(inputs, padding_mask, global_conditioning)
617
+ inputs = torch.flip(inputs, [1])
618
+ else:
619
+ for flow in reversed(self.flows):
620
+ inputs = torch.flip(inputs, [1])
621
+ inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
622
+ return inputs
623
+
624
+
625
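Editor's note: because the coupling layer leaves the first half untouched and only shifts the second half by a function of the first, `reverse=True` inverts the forward pass exactly; a round-trip check on a default config (assumes `VitsConfig` is importable from `transformers`):

import torch
from transformers import VitsConfig

config = VitsConfig()
layer = VitsResidualCouplingLayer(config).eval()
x = torch.randn(1, config.flow_size, 20)
mask = torch.ones(1, 1, 20)
y, _ = layer(x, mask)
x_rec, _ = layer(y, mask, reverse=True)
assert torch.allclose(x, x_rec, atol=1e-5)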
+ class VitsDilatedDepthSeparableConv(nn.Module):
626
+ def __init__(self, config: VitsConfig, dropout_rate=0.0):
627
+ super().__init__()
628
+ kernel_size = config.duration_predictor_kernel_size
629
+ channels = config.hidden_size
630
+ self.num_layers = config.depth_separable_num_layers
631
+
632
+ self.dropout = nn.Dropout(dropout_rate)
633
+ self.convs_dilated = nn.ModuleList()
634
+ self.convs_pointwise = nn.ModuleList()
635
+ self.norms_1 = nn.ModuleList()
636
+ self.norms_2 = nn.ModuleList()
637
+ for i in range(self.num_layers):
638
+ dilation = kernel_size**i
639
+ padding = (kernel_size * dilation - dilation) // 2
640
+ self.convs_dilated.append(
641
+ nn.Conv1d(
642
+ in_channels=channels,
643
+ out_channels=channels,
644
+ kernel_size=kernel_size,
645
+ groups=channels,
646
+ dilation=dilation,
647
+ padding=padding,
648
+ )
649
+ )
650
+ self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
651
+ self.norms_1.append(nn.LayerNorm(channels))
652
+ self.norms_2.append(nn.LayerNorm(channels))
653
+
654
+ def forward(self, inputs, padding_mask, global_conditioning=None):
655
+ if global_conditioning is not None:
656
+ inputs = inputs + global_conditioning
657
+
658
+ for i in range(self.num_layers):
659
+ hidden_states = self.convs_dilated[i](inputs * padding_mask)
660
+ hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
661
+ hidden_states = nn.functional.gelu(hidden_states)
662
+ hidden_states = self.convs_pointwise[i](hidden_states)
663
+ hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
664
+ hidden_states = nn.functional.gelu(hidden_states)
665
+ hidden_states = self.dropout(hidden_states)
666
+ inputs = inputs + hidden_states
667
+
668
+ return inputs * padding_mask
669
+
670
+
671
+ class VitsConvFlow(nn.Module):
672
+ def __init__(self, config: VitsConfig):
673
+ super().__init__()
674
+ self.filter_channels = config.hidden_size
675
+ self.half_channels = config.depth_separable_channels // 2
676
+ self.num_bins = config.duration_predictor_flow_bins
677
+ self.tail_bound = config.duration_predictor_tail_bound
678
+
679
+ self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
680
+ self.conv_dds = VitsDilatedDepthSeparableConv(config)
681
+ self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
682
+
683
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
684
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
685
+
686
+ hidden_states = self.conv_pre(first_half)
687
+ hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
688
+ hidden_states = self.conv_proj(hidden_states) * padding_mask
689
+
690
+ batch_size, channels, length = first_half.shape
691
+ hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
692
+
693
+ unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
694
+ unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
695
+ unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
696
+
697
+ second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
698
+ second_half,
699
+ unnormalized_widths,
700
+ unnormalized_heights,
701
+ unnormalized_derivatives,
702
+ reverse=reverse,
703
+ tail_bound=self.tail_bound,
704
+ )
705
+
706
+ outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
707
+ if not reverse:
708
+ log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
709
+ return outputs, log_determinant
710
+ else:
711
+ return outputs, None
712
+
713
+
714
+ class VitsElementwiseAffine(nn.Module):
715
+ def __init__(self, config: VitsConfig):
716
+ super().__init__()
717
+ self.channels = config.depth_separable_channels
718
+ self.translate = nn.Parameter(torch.zeros(self.channels, 1))
719
+ self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
720
+
721
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
722
+ if not reverse:
723
+ outputs = self.translate + torch.exp(self.log_scale) * inputs
724
+ outputs = outputs * padding_mask
725
+ log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
726
+ return outputs, log_determinant
727
+ else:
728
+ outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
729
+ return outputs, None
730
+
731
+
732
+ class VitsStochasticDurationPredictor(nn.Module):
733
+ def __init__(self, config):
734
+ super().__init__()
735
+ embed_dim = config.speaker_embedding_size
736
+ filter_channels = config.hidden_size
737
+
738
+ self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
739
+ self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
740
+ self.conv_dds = VitsDilatedDepthSeparableConv(
741
+ config,
742
+ dropout_rate=config.duration_predictor_dropout,
743
+ )
744
+
745
+ if embed_dim != 0:
746
+ self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
747
+
748
+ self.flows = nn.ModuleList()
749
+ self.flows.append(VitsElementwiseAffine(config))
750
+ for _ in range(config.duration_predictor_num_flows):
751
+ self.flows.append(VitsConvFlow(config))
752
+
753
+ self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
754
+ self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
755
+ self.post_conv_dds = VitsDilatedDepthSeparableConv(
756
+ config,
757
+ dropout_rate=config.duration_predictor_dropout,
758
+ )
759
+
760
+ self.post_flows = nn.ModuleList()
761
+ self.post_flows.append(VitsElementwiseAffine(config))
762
+ for _ in range(config.duration_predictor_num_flows):
763
+ self.post_flows.append(VitsConvFlow(config))
764
+
765
+ def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
766
+ inputs = torch.detach(inputs)
767
+ inputs = self.conv_pre(inputs)
768
+
769
+ if global_conditioning is not None:
770
+ global_conditioning = torch.detach(global_conditioning)
771
+ inputs = inputs + self.cond(global_conditioning)
772
+
773
+ inputs = self.conv_dds(inputs, padding_mask)
774
+ inputs = self.conv_proj(inputs) * padding_mask
775
+
776
+ if not reverse:
777
+ hidden_states = self.post_conv_pre(durations)
778
+ hidden_states = self.post_conv_dds(hidden_states, padding_mask)
779
+ hidden_states = self.post_conv_proj(hidden_states) * padding_mask
780
+
781
+ random_posterior = (
782
+ torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
783
+ * padding_mask
784
+ )
785
+ log_determinant_posterior_sum = 0
786
+ latents_posterior = random_posterior
787
+ for flow in self.post_flows:
788
+ latents_posterior, log_determinant = flow(
789
+ latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
790
+ )
791
+ latents_posterior = torch.flip(latents_posterior, [1])
792
+ log_determinant_posterior_sum += log_determinant
793
+
794
+ first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
795
+
796
+ log_determinant_posterior_sum += torch.sum(
797
+ (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
798
+ )
799
+ logq = (
800
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
801
+ - log_determinant_posterior_sum
802
+ )
803
+
804
+ first_half = (durations - torch.sigmoid(first_half)) * padding_mask
805
+ first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
806
+ log_determinant_sum = torch.sum(-first_half, [1, 2])
807
+
808
+ latents = torch.cat([first_half, second_half], dim=1)
809
+ for flow in self.flows:
810
+ latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
811
+ latents = torch.flip(latents, [1])
812
+ log_determinant_sum += log_determinant
813
+
814
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
815
+ return nll + logq
816
+ else:
817
+ flows = list(reversed(self.flows))
818
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
819
+
820
+ latents = (
821
+ torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
822
+ * noise_scale
823
+ )
824
+ for flow in flows:
825
+ latents = torch.flip(latents, [1])
826
+ latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
827
+
828
+ log_duration, _ = torch.split(latents, [1, 1], dim=1)
829
+ return log_duration
830
+
831
+
832
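Editor's note: at inference (`reverse=True`) the stochastic predictor samples noise and pushes it backwards through the flows to obtain per-token log-durations; a sketch of the downstream use, run in this module's namespace (shapes assume the default `hidden_size` of 192, and `noise_scale=0.8` is the assumed `noise_scale_duration` default):

import torch
from transformers import VitsConfig

predictor = VitsStochasticDurationPredictor(VitsConfig()).eval()
hidden = torch.randn(1, 192, 12)  # (batch, hidden_size, num_tokens)
mask = torch.ones(1, 1, 12)
log_duration = predictor(hidden, mask, reverse=True, noise_scale=0.8)
duration = torch.ceil(torch.exp(log_duration) * mask)  # frames per token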
+ class VitsDurationPredictor(nn.Module):
833
+ def __init__(self, config):
834
+ super().__init__()
835
+ kernel_size = config.duration_predictor_kernel_size
836
+ filter_channels = config.duration_predictor_filter_channels
837
+
838
+ self.dropout = nn.Dropout(config.duration_predictor_dropout)
839
+ self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
840
+ self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
841
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
842
+ self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
843
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
844
+
845
+ if config.speaker_embedding_size != 0:
846
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)
847
+
848
+ def forward(self, inputs, padding_mask, global_conditioning=None):
849
+ inputs = torch.detach(inputs)
850
+
851
+ if global_conditioning is not None:
852
+ global_conditioning = torch.detach(global_conditioning)
853
+ inputs = inputs + self.cond(global_conditioning)
854
+
855
+ inputs = self.conv_1(inputs * padding_mask)
856
+ inputs = torch.relu(inputs)
857
+ inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
858
+ inputs = self.dropout(inputs)
859
+
860
+ inputs = self.conv_2(inputs * padding_mask)
861
+ inputs = torch.relu(inputs)
862
+ inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
863
+ inputs = self.dropout(inputs)
864
+
865
+ inputs = self.proj(inputs * padding_mask)
866
+ return inputs * padding_mask
867
+
868
+
869
+ class VitsAttention(nn.Module):
870
+ """Multi-headed attention with relative positional representation."""
871
+
872
+ def __init__(self, config: VitsConfig):
873
+ super().__init__()
874
+ self.embed_dim = config.hidden_size
875
+ self.num_heads = config.num_attention_heads
876
+ self.dropout = config.attention_dropout
877
+ self.window_size = config.window_size
878
+
879
+ self.head_dim = self.embed_dim // self.num_heads
880
+ self.scaling = self.head_dim**-0.5
881
+
882
+ if (self.head_dim * self.num_heads) != self.embed_dim:
883
+ raise ValueError(
884
+ f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
885
+ f" and `num_attention_heads`: {self.num_heads})."
886
+ )
887
+
888
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
889
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
890
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
891
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
892
+
893
+ if self.window_size:
894
+ self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
895
+ self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
896
+
897
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
898
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
899
+
900
+ def forward(
901
+ self,
902
+ hidden_states: torch.Tensor,
903
+ key_value_states: Optional[torch.Tensor] = None,
904
+ attention_mask: Optional[torch.Tensor] = None,
905
+ layer_head_mask: Optional[torch.Tensor] = None,
906
+ output_attentions: bool = False,
907
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
908
+ """Input shape: Batch x Time x Channel"""
909
+
910
+ # if key_value_states are provided this layer is used as a cross-attention layer
911
+ # for the decoder
912
+
913
+ bsz, tgt_len, _ = hidden_states.size()
914
+
915
+ # get query proj
916
+ query_states = self.q_proj(hidden_states) * self.scaling
917
+
918
+ # self_attention
919
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
920
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
921
+
922
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
923
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
924
+ key_states = key_states.view(*proj_shape)
925
+ value_states = value_states.view(*proj_shape)
926
+
927
+ src_len = key_states.size(1)
928
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
929
+
930
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
931
+ raise ValueError(
932
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
933
+ f" {attn_weights.size()}"
934
+ )
935
+
936
+ if self.window_size is not None:
937
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
938
+ relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
939
+ rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
940
+ attn_weights += rel_pos_bias
941
+
942
+ if attention_mask is not None:
943
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
944
+ raise ValueError(
945
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
946
+ )
947
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
948
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
949
+
950
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
951
+
952
+ if layer_head_mask is not None:
953
+ if layer_head_mask.size() != (self.num_heads,):
954
+ raise ValueError(
955
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
956
+ f" {layer_head_mask.size()}"
957
+ )
958
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
959
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
960
+
961
+ if output_attentions:
962
+ # this operation is a bit awkward, but it's required to
963
+ # make sure that attn_weights keeps its gradient.
964
+ # In order to do so, attn_weights have to be reshaped
965
+ # twice and have to be reused in the following
966
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
967
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
968
+ else:
969
+ attn_weights_reshaped = None
970
+
971
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
972
+
973
+ attn_output = torch.bmm(attn_probs, value_states)
974
+
975
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
976
+ raise ValueError(
977
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
978
+ f" {attn_output.size()}"
979
+ )
980
+
981
+ if self.window_size is not None:
982
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
983
+ relative_weights = self._absolute_position_to_relative_position(attn_probs)
984
+ rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
985
+ attn_output += rel_pos_bias
986
+
987
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
988
+ attn_output = attn_output.transpose(1, 2)
989
+
990
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
991
+ # partitioned across GPUs when using tensor-parallelism.
992
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
993
+
994
+ attn_output = self.out_proj(attn_output)
995
+
996
+ return attn_output, attn_weights_reshaped
997
+
998
+ def _get_relative_embeddings(self, relative_embeddings, length):
999
+ pad_length = max(length - (self.window_size + 1), 0)
1000
+ if pad_length > 0:
1001
+ relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
1002
+
1003
+ slice_start_position = max((self.window_size + 1) - length, 0)
1004
+ slice_end_position = slice_start_position + 2 * length - 1
1005
+ return relative_embeddings[:, slice_start_position:slice_end_position]
1006
+
1007
+ def _relative_position_to_absolute_position(self, x):
1008
+ batch_heads, length, _ = x.size()
1009
+
1010
+ # Concat columns of pad to shift from relative to absolute indexing.
1011
+ x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
1012
+
1013
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
1014
+ x_flat = x.view([batch_heads, length * 2 * length])
1015
+ x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
1016
+
1017
+ # Reshape and slice out the padded elements.
1018
+ x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
1019
+ x_final = x_final[:, :length, length - 1 :]
1020
+ return x_final
1021
+
1022
+ def _absolute_position_to_relative_position(self, x):
1023
+ batch_heads, length, _ = x.size()
1024
+
1025
+ # Pad along column
1026
+ x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
1027
+ x_flat = x.view([batch_heads, length * (2 * length - 1)])
1028
+
1029
+ # Add 0's in the beginning that will skew the elements after reshape
1030
+ x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
1031
+ x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
1032
+ return x_final
1033
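The two index-shuffling helpers above implement the standard "skewing" trick for relative attention: a pad, a flatten, a second pad, and a reshape move scores between the absolute (query, key) layout and the relative (query, offset) layout without any gather ops, and on the valid region they are exact inverses. A minimal standalone sketch, with the methods rewritten as free functions so the round trip can be checked directly (shapes are illustrative):

```python
import torch
import torch.nn as nn


def rel_to_abs(x):
    # mirrors _relative_position_to_absolute_position: (B*H, L, 2L-1) -> (B*H, L, L)
    batch_heads, length, _ = x.size()
    x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])  # one extra column shifts each row
    x_flat = x.view([batch_heads, length * 2 * length])
    x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
    x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
    return x_final[:, :length, length - 1 :]


def abs_to_rel(x):
    # mirrors _absolute_position_to_relative_position: (B*H, L, L) -> (B*H, L, 2L-1)
    batch_heads, length, _ = x.size()
    x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
    x_flat = x.view([batch_heads, length * (2 * length - 1)])
    x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])  # leading zeros skew the reshape
    return x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]


scores = torch.randn(2, 5, 5)
assert torch.allclose(rel_to_abs(abs_to_rel(scores)), scores)  # exact round trip
```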
+
1034
+
1035
+ class VitsFeedForward(nn.Module):
1036
+ def __init__(self, config):
1037
+ super().__init__()
1038
+ self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
1039
+ self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
1040
+ self.dropout = nn.Dropout(config.activation_dropout)
1041
+
1042
+ if isinstance(config.hidden_act, str):
1043
+ self.act_fn = ACT2FN[config.hidden_act]
1044
+ else:
1045
+ self.act_fn = config.hidden_act
1046
+
1047
+ if config.ffn_kernel_size > 1:
1048
+ pad_left = (config.ffn_kernel_size - 1) // 2
1049
+ pad_right = config.ffn_kernel_size // 2
1050
+ self.padding = [pad_left, pad_right, 0, 0, 0, 0]
1051
+ else:
1052
+ self.padding = None
1053
+
1054
+ def forward(self, hidden_states, padding_mask):
1055
+ hidden_states = hidden_states.permute(0, 2, 1)
1056
+ padding_mask = padding_mask.permute(0, 2, 1)
1057
+
1058
+ hidden_states = hidden_states * padding_mask
1059
+ if self.padding is not None:
1060
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
1061
+
1062
+ hidden_states = self.conv_1(hidden_states)
1063
+ hidden_states = self.act_fn(hidden_states)
1064
+ hidden_states = self.dropout(hidden_states)
1065
+
1066
+ hidden_states = hidden_states * padding_mask
1067
+ if self.padding is not None:
1068
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
1069
+
1070
+ hidden_states = self.conv_2(hidden_states)
1071
+ hidden_states = hidden_states * padding_mask
1072
+
1073
+ hidden_states = hidden_states.permute(0, 2, 1)
1074
+ return hidden_states
1075
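A note on the padding arithmetic in `VitsFeedForward.__init__`: padding `(k - 1) // 2` on the left and `k // 2` on the right gives "same"-length convolutions for both odd and even kernel sizes. A quick check with toy channel sizes (not taken from any config):

```python
import torch
import torch.nn as nn

k = 3  # stands in for config.ffn_kernel_size
pad_left, pad_right = (k - 1) // 2, k // 2

x = torch.randn(1, 8, 50)                        # (batch, channels, time)
x = nn.functional.pad(x, [pad_left, pad_right])  # pads only the last (time) dimension
y = nn.Conv1d(8, 16, k)(x)
print(y.shape)                                   # torch.Size([1, 16, 50]) -- length preserved
```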
+
1076
+
1077
+ class VitsEncoderLayer(nn.Module):
1078
+ def __init__(self, config: VitsConfig):
1079
+ super().__init__()
1080
+ self.attention = VitsAttention(config)
1081
+ self.dropout = nn.Dropout(config.hidden_dropout)
1082
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1083
+ self.feed_forward = VitsFeedForward(config)
1084
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1085
+
1086
+ def forward(
1087
+ self,
1088
+ hidden_states: torch.Tensor,
1089
+ padding_mask: torch.FloatTensor,
1090
+ attention_mask: Optional[torch.Tensor] = None,
1091
+ output_attentions: bool = False,
1092
+ ):
1093
+ residual = hidden_states
1094
+ hidden_states, attn_weights = self.attention(
1095
+ hidden_states=hidden_states,
1096
+ attention_mask=attention_mask,
1097
+ output_attentions=output_attentions,
1098
+ )
1099
+
1100
+ hidden_states = self.dropout(hidden_states)
1101
+ hidden_states = self.layer_norm(residual + hidden_states)
1102
+
1103
+ residual = hidden_states
1104
+ hidden_states = self.feed_forward(hidden_states, padding_mask)
1105
+ hidden_states = self.dropout(hidden_states)
1106
+ hidden_states = self.final_layer_norm(residual + hidden_states)
1107
+
1108
+ outputs = (hidden_states,)
1109
+
1110
+ if output_attentions:
1111
+ outputs += (attn_weights,)
1112
+
1113
+ return outputs
1114
+
1115
+
1116
+ class VitsEncoder(nn.Module):
1117
+ def __init__(self, config: VitsConfig):
1118
+ super().__init__()
1119
+ self.config = config
1120
+ self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
1121
+ self.gradient_checkpointing = False
1122
+ self.layerdrop = config.layerdrop
1123
+
1124
+ def forward(
1125
+ self,
1126
+ hidden_states: torch.FloatTensor,
1127
+ padding_mask: torch.FloatTensor,
1128
+ attention_mask: Optional[torch.Tensor] = None,
1129
+ output_attentions: Optional[bool] = None,
1130
+ output_hidden_states: Optional[bool] = None,
1131
+ return_dict: Optional[bool] = None,
1132
+ ) -> Union[Tuple, BaseModelOutput]:
1133
+ all_hidden_states = () if output_hidden_states else None
1134
+ all_self_attentions = () if output_attentions else None
1135
+
1136
+ # expand attention_mask
1137
+ if attention_mask is not None:
1138
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1139
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1140
+
1141
+ hidden_states = hidden_states * padding_mask
1142
+
1143
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
1144
+
1145
+ for encoder_layer in self.layers:
1146
+ if output_hidden_states:
1147
+ all_hidden_states = all_hidden_states + (hidden_states,)
1148
+
1149
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1150
+ dropout_probability = np.random.uniform(0, 1)
1151
+
1152
+ skip_the_layer = self.training and (dropout_probability < self.layerdrop)
1153
+ if not skip_the_layer or synced_gpus:
1154
+ # under fsdp or deepspeed zero3 all gpus must run in sync
1155
+ if self.gradient_checkpointing and self.training:
1156
+ layer_outputs = self._gradient_checkpointing_func(
1157
+ encoder_layer.__call__,
1158
+ hidden_states,
1159
+ padding_mask,
1160
+ attention_mask,
1161
+ output_attentions,
1162
+ )
1163
+ else:
1164
+ layer_outputs = encoder_layer(
1165
+ hidden_states,
1166
+ attention_mask=attention_mask,
1167
+ padding_mask=padding_mask,
1168
+ output_attentions=output_attentions,
1169
+ )
1170
+ hidden_states = layer_outputs[0]
1171
+
1172
+ if skip_the_layer:
1173
+ layer_outputs = (None, None)
1174
+
1175
+ if output_attentions:
1176
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1177
+
1178
+ hidden_states = hidden_states * padding_mask
1179
+
1180
+ if output_hidden_states:
1181
+ all_hidden_states = all_hidden_states + (hidden_states,)
1182
+
1183
+ if not return_dict:
1184
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1185
+
1186
+ return BaseModelOutput(
1187
+ last_hidden_state=hidden_states,
1188
+ hidden_states=all_hidden_states,
1189
+ attentions=all_self_attentions,
1190
+ )
1191
+
1192
+
1193
+ class VitsTextEncoder(nn.Module):
1194
+ """
1195
+ Transformer encoder that uses relative positional representation instead of absolute positional encoding.
1196
+ """
1197
+
1198
+ def __init__(self, config: VitsConfig):
1199
+ super().__init__()
1200
+ self.config = config
1201
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1202
+ self.encoder = VitsEncoder(config)
1203
+ self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
1204
+
1205
+ def get_input_embeddings(self):
1206
+ return self.embed_tokens
1207
+
1208
+ def set_input_embeddings(self, value):
1209
+ self.embed_tokens = value
1210
+
1211
+ def forward(
1212
+ self,
1213
+ input_ids: torch.Tensor,
1214
+ padding_mask: torch.FloatTensor,
1215
+ attention_mask: Optional[torch.Tensor] = None,
1216
+ output_attentions: Optional[bool] = None,
1217
+ output_hidden_states: Optional[bool] = None,
1218
+ return_dict: Optional[bool] = True,
1219
+ ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]:
1220
+ hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
1221
+
1222
+ encoder_outputs = self.encoder(
1223
+ hidden_states=hidden_states,
1224
+ padding_mask=padding_mask,
1225
+ attention_mask=attention_mask,
1226
+ output_attentions=output_attentions,
1227
+ output_hidden_states=output_hidden_states,
1228
+ return_dict=return_dict,
1229
+ )
1230
+
1231
+ last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
1232
+
1233
+ stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
1234
+ prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
1235
+
1236
+ if not return_dict:
1237
+ outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
1238
+ return outputs
1239
+
1240
+ return VitsTextEncoderOutput(
1241
+ last_hidden_state=last_hidden_state,
1242
+ prior_means=prior_means,
1243
+ prior_log_variances=prior_log_variances,
1244
+ hidden_states=encoder_outputs.hidden_states,
1245
+ attentions=encoder_outputs.attentions,
1246
+ )
1247
+
1248
+
1249
+ class VitsPreTrainedModel(PreTrainedModel):
1250
+ """
1251
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
1252
+ models.
1253
+ """
1254
+
1255
+ config_class = VitsConfig
1256
+ base_model_prefix = "vits"
1257
+ main_input_name = "input_ids"
1258
+ supports_gradient_checkpointing = True
1259
+
1260
+ def _init_weights(self, module):
1261
+ """Initialize the weights"""
1262
+ if isinstance(module, nn.Linear):
1263
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1264
+ if module.bias is not None:
1265
+ module.bias.data.zero_()
1266
+ elif isinstance(module, nn.LayerNorm):
1267
+ module.bias.data.zero_()
1268
+ module.weight.data.fill_(1.0)
1269
+ elif isinstance(module, nn.Conv1d):
1270
+ nn.init.kaiming_normal_(module.weight)
1271
+ if module.bias is not None:
1272
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
1273
+ nn.init.uniform_(module.bias, a=-k, b=k)
1274
+ elif isinstance(module, nn.Embedding):
1275
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1276
+ if module.padding_idx is not None:
1277
+ module.weight.data[module.padding_idx].zero_()
1278
+
1279
+
1280
+ VITS_START_DOCSTRING = r"""
1281
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1282
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1283
+ etc.)
1284
+
1285
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1286
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1287
+ and behavior.
1288
+
1289
+ Parameters:
1290
+ config ([`VitsConfig`]):
1291
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1292
+ load the weights associated with the model, only the configuration. Check out the
1293
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1294
+ """
1295
+
1296
+
1297
+ VITS_INPUTS_DOCSTRING = r"""
1298
+ Args:
1299
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1300
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1301
+ it.
1302
+
1303
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1304
+ [`PreTrainedTokenizer.__call__`] for details.
1305
+
1306
+ [What are input IDs?](../glossary#input-ids)
1307
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1308
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1309
+ 1]`:
1310
+
1311
+ - 1 for tokens that are **not masked**,
1312
+ - 0 for tokens that are **masked**.
1313
+
1314
+ [What are attention masks?](../glossary#attention-mask)
1315
+ speaker_id (`int`, *optional*):
1316
+ Which speaker embedding to use. Only used for multispeaker models.
1317
+ output_attentions (`bool`, *optional*):
1318
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1319
+ tensors for more detail.
1320
+ output_hidden_states (`bool`, *optional*):
1321
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1322
+ more detail.
1323
+ return_dict (`bool`, *optional*):
1324
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1325
+ """
1326
+
1327
+
1328
+ @add_start_docstrings(
1329
+ "The complete VITS model, for text-to-speech synthesis.",
1330
+ VITS_START_DOCSTRING,
1331
+ )
1332
+ class VitsModel(VitsPreTrainedModel):
1333
+ def __init__(self, config: VitsConfig):
1334
+ super().__init__(config)
1335
+ self.config = config
1336
+ self.text_encoder = VitsTextEncoder(config)
1337
+ self.flow = VitsResidualCouplingBlock(config)
1338
+ self.decoder = VitsHifiGan(config)
1339
+
1340
+ if config.use_stochastic_duration_prediction:
1341
+ self.duration_predictor = VitsStochasticDurationPredictor(config)
1342
+ else:
1343
+ self.duration_predictor = VitsDurationPredictor(config)
1344
+
1345
+ if config.num_speakers > 1:
1346
+ self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
1347
+
1348
+ # This is used only for training.
1349
+ self.posterior_encoder = VitsPosteriorEncoder(config)
1350
+
1351
+ # These parameters control the synthesised speech properties
1352
+ self.speaking_rate = config.speaking_rate
1353
+ self.noise_scale = config.noise_scale
1354
+ self.noise_scale_duration = config.noise_scale_duration
1355
+
1356
+ # Initialize weights and apply final processing
1357
+ self.post_init()
1358
+
1359
+ def get_encoder(self):
1360
+ return self.text_encoder
1361
+
1362
+ @add_start_docstrings_to_model_forward(VITS_INPUTS_DOCSTRING)
1363
+ @replace_return_docstrings(output_type=VitsModelOutput, config_class=_CONFIG_FOR_DOC)
1364
+ def forward(
1365
+ self,
1366
+ input_ids: Optional[torch.Tensor] = None,
1367
+ attention_mask: Optional[torch.Tensor] = None,
1368
+ speaker_id: Optional[int] = None,
1369
+ output_attentions: Optional[bool] = None,
1370
+ output_hidden_states: Optional[bool] = None,
1371
+ return_dict: Optional[bool] = None,
1372
+ labels: Optional[torch.FloatTensor] = None,
1373
+ ) -> Union[Tuple[Any], VitsModelOutput]:
1374
+ r"""
1375
+ labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
1376
+ Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
1377
+ computation.
1378
+
1379
+ Returns:
1380
+
1381
+ Example:
1382
+
1383
+ ```python
1384
+ >>> from transformers import VitsTokenizer, VitsModel, set_seed
1385
+ >>> import torch
1386
+
1387
+ >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
1388
+ >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")
1389
+
1390
+ >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
1391
+
1392
+ >>> set_seed(555) # make deterministic
1393
+
1394
+ >>> with torch.no_grad():
1395
+ ... outputs = model(inputs["input_ids"])
1396
+ >>> outputs.waveform.shape
1397
+ torch.Size([1, 45824])
1398
+ ```
1399
+ """
1400
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1401
+ output_hidden_states = (
1402
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1403
+ )
1404
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1405
+
1406
+ if labels is not None:
1407
+ raise NotImplementedError("Training of VITS is not supported yet.")
1408
+
1409
+ mask_dtype = self.text_encoder.embed_tokens.weight.dtype
1410
+ if attention_mask is not None:
1411
+ input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
1412
+ else:
1413
+ input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
1414
+
1415
+ if self.config.num_speakers > 1 and speaker_id is not None:
1416
+ if not 0 <= speaker_id < self.config.num_speakers:
1417
+ raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
1418
+ if isinstance(speaker_id, int):
1419
+ speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
1420
+ speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
1421
+ else:
1422
+ speaker_embeddings = None
1423
+
1424
+ text_encoder_output = self.text_encoder(
1425
+ input_ids=input_ids,
1426
+ padding_mask=input_padding_mask,
1427
+ attention_mask=attention_mask,
1428
+ output_attentions=output_attentions,
1429
+ output_hidden_states=output_hidden_states,
1430
+ return_dict=return_dict,
1431
+ )
1432
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
1433
+ hidden_states = hidden_states.transpose(1, 2)
1434
+ input_padding_mask = input_padding_mask.transpose(1, 2)
1435
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
1436
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
1437
+
1438
+ if self.config.use_stochastic_duration_prediction:
1439
+ log_duration = self.duration_predictor(
1440
+ hidden_states,
1441
+ input_padding_mask,
1442
+ speaker_embeddings,
1443
+ reverse=True,
1444
+ noise_scale=self.noise_scale_duration,
1445
+ )
1446
+ else:
1447
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
1448
+
1449
+ length_scale = 1.0 / self.speaking_rate
1450
+ duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
1451
+ predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
1452
+
1453
+ # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
1454
+ indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
1455
+ output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
1456
+ output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
1457
+
1458
+ # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
1459
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
1460
+ batch_size, _, output_length, input_length = attn_mask.shape
1461
+ cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
1462
+ indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
1463
+ valid_indices = indices.unsqueeze(0) < cum_duration
1464
+ valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
1465
+ padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
1466
+ attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
1467
+
1468
+ # Expand prior distribution
1469
+ prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
1470
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
1471
+
1472
+ prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
1473
+ latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
1474
+
1475
+ spectrogram = latents * output_padding_mask
1476
+ waveform = self.decoder(spectrogram, speaker_embeddings)
1477
+ waveform = waveform.squeeze(1)
1478
+ sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
1479
+
1480
+ if not return_dict:
1481
+ outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
1482
+ return outputs
1483
+
1484
+ return VitsModelOutput(
1485
+ waveform=waveform,
1486
+ sequence_lengths=sequence_lengths,
1487
+ spectrogram=spectrogram,
1488
+ hidden_states=text_encoder_output.hidden_states,
1489
+ attentions=text_encoder_output.attentions,
1490
+ )
1491
+
1492
+
1493
+ __all__ = ["VitsModel", "VitsPreTrainedModel"]
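The easiest part of `VitsModel.forward` to misread is the `cumsum` block that turns predicted per-token durations into a hard monotonic alignment. A toy sketch of the same construction, assuming a single batch element and hand-picked durations:

```python
import torch
import torch.nn as nn

duration = torch.tensor([[2.0, 1.0, 3.0]])  # 3 input tokens -> 6 output frames
batch_size, input_length = duration.shape
output_length = int(duration.sum())

cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
indices = torch.arange(output_length, dtype=duration.dtype)
valid_indices = (indices.unsqueeze(0) < cum_duration).float()
valid_indices = valid_indices.view(batch_size, input_length, output_length)
# Subtracting the row above turns cumulative coverage into disjoint one-hot spans.
attn = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0])[:, :-1]
print(attn[0].long())
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1]])
```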
docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py ADDED
@@ -0,0 +1,246 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization class for VITS."""
16
+
17
+ import json
18
+ import os
19
+ import re
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+
22
+ from ...tokenization_utils import PreTrainedTokenizer
23
+ from ...utils import is_phonemizer_available, is_uroman_available, logging
24
+
25
+
26
+ if is_phonemizer_available():
27
+ import phonemizer
28
+
29
+ if is_uroman_available():
30
+ import uroman as ur
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"}
35
+
36
+
37
+ def has_non_roman_characters(input_string):
38
+ # Find any character outside the ASCII range
39
+ non_roman_pattern = re.compile(r"[^\x00-\x7F]")
40
+
41
+ # Search the input string for non-Roman characters
42
+ match = non_roman_pattern.search(input_string)
43
+ has_non_roman = match is not None
44
+ return has_non_roman
45
+
46
+
47
+ class VitsTokenizer(PreTrainedTokenizer):
48
+ """
49
+ Construct a VITS tokenizer. Also supports MMS-TTS.
50
+
51
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
52
+ this superclass for more information regarding those methods.
53
+
54
+ Args:
55
+ vocab_file (`str`):
56
+ Path to the vocabulary file.
57
+ language (`str`, *optional*):
58
+ Language identifier.
59
+ add_blank (`bool`, *optional*, defaults to `True`):
60
+ Whether to insert token id 0 in between the other tokens.
61
+ normalize (`bool`, *optional*, defaults to `True`):
62
+ Whether to normalize the input text by removing all casing and punctuation.
63
+ phonemize (`bool`, *optional*, defaults to `True`):
64
+ Whether to convert the input text into phonemes.
65
+ is_uroman (`bool`, *optional*, defaults to `False`):
66
+ Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
67
+ """
68
+
69
+ vocab_files_names = VOCAB_FILES_NAMES
70
+ model_input_names = ["input_ids", "attention_mask"]
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file,
75
+ pad_token="<pad>",
76
+ unk_token="<unk>",
77
+ language=None,
78
+ add_blank=True,
79
+ normalize=True,
80
+ phonemize=True,
81
+ is_uroman=False,
82
+ **kwargs,
83
+ ) -> None:
84
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
85
+ self.encoder = json.load(vocab_handle)
86
+
87
+ self.decoder = {v: k for k, v in self.encoder.items()}
88
+ self.language = language
89
+ self.add_blank = add_blank
90
+ self.normalize = normalize
91
+ self.phonemize = phonemize
92
+
93
+ self.is_uroman = is_uroman
94
+
95
+ super().__init__(
96
+ pad_token=pad_token,
97
+ unk_token=unk_token,
98
+ language=language,
99
+ add_blank=add_blank,
100
+ normalize=normalize,
101
+ phonemize=phonemize,
102
+ is_uroman=is_uroman,
103
+ **kwargs,
104
+ )
105
+
106
+ @property
107
+ def vocab_size(self):
108
+ return len(self.encoder)
109
+
110
+ def get_vocab(self):
111
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
112
+ vocab.update(self.added_tokens_encoder)
113
+ return vocab
114
+
115
+ def normalize_text(self, input_string):
116
+ """Lowercase the input string, respecting any special token ids that may be partly or entirely upper-cased."""
117
+ all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
118
+ filtered_text = ""
119
+
120
+ i = 0
121
+ while i < len(input_string):
122
+ found_match = False
123
+ for word in all_vocabulary:
124
+ if input_string[i : i + len(word)] == word:
125
+ filtered_text += word
126
+ i += len(word)
127
+ found_match = True
128
+ break
129
+
130
+ if not found_match:
131
+ filtered_text += input_string[i].lower()
132
+ i += 1
133
+
134
+ return filtered_text
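`normalize_text` lowercases character by character but copies any exact vocabulary match verbatim, so cased special tokens survive normalization. A standalone sketch of the same loop (`normalize` and the tiny vocabulary are illustrative, not part of the tokenizer API):

```python
def normalize(text, vocabulary):
    out, i = "", 0
    while i < len(text):
        for word in vocabulary:
            if text[i : i + len(word)] == word:
                out += word          # vocabulary match kept verbatim
                i += len(word)
                break
        else:
            out += text[i].lower()   # everything else is lowercased
            i += 1
    return out


print(normalize("Hello<pad>", ["<pad>", "e", "l", "o", "h"]))  # hello<pad>
```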
135
+
136
+ def _preprocess_char(self, text):
137
+ """Special treatment of characters in certain languages"""
138
+ if self.language == "ron":
139
+ text = text.replace("ț", "ţ")
140
+ return text
141
+
142
+ def prepare_for_tokenization(
143
+ self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
144
+ ) -> Tuple[str, Dict[str, Any]]:
145
+ """
146
+ Performs any necessary transformations before tokenization.
147
+
148
+ This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
149
+ `kwargs` at the end of the encoding process to be sure all the arguments have been used.
150
+
151
+ Args:
152
+ text (`str`):
153
+ The text to prepare.
154
+ is_split_into_words (`bool`, *optional*, defaults to `False`):
155
+ Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
156
+ tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
157
+ which it will tokenize.
158
+ normalize (`bool`, *optional*, defaults to `None`):
159
+ Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
160
+ trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
161
+ text consists only of lower-case characters.
162
+ kwargs (`Dict[str, Any]`, *optional*):
163
+ Keyword arguments to use for the tokenization.
164
+
165
+ Returns:
166
+ `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
167
+ """
168
+ normalize = normalize if normalize is not None else self.normalize
169
+
170
+ if normalize:
171
+ # normalise for casing
172
+ text = self.normalize_text(text)
173
+
174
+ filtered_text = self._preprocess_char(text)
175
+
176
+ if has_non_roman_characters(filtered_text) and self.is_uroman:
177
+ if not is_uroman_available():
178
+ logger.warning(
179
+ "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
180
+ "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman`. "
181
+ "Note `uroman` requires python version >= 3.10. "
182
+ "Otherwise, apply the Romanizer manually as per the instructions: https://github.com/isi-nlp/uroman"
183
+ )
184
+ else:
185
+ uroman = ur.Uroman()
186
+ filtered_text = uroman.romanize_string(filtered_text)
187
+
188
+ if self.phonemize:
189
+ if not is_phonemizer_available():
190
+ raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")
191
+
192
+ filtered_text = phonemizer.phonemize(
193
+ filtered_text,
194
+ language="en-us",
195
+ backend="espeak",
196
+ strip=True,
197
+ preserve_punctuation=True,
198
+ with_stress=True,
199
+ )
200
+ filtered_text = re.sub(r"\s+", " ", filtered_text)
201
+ elif normalize:
202
+ # strip any chars outside of the vocab (punctuation)
203
+ filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()
204
+
205
+ return filtered_text, kwargs
206
+
207
+ def _tokenize(self, text: str) -> List[str]:
208
+ """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
209
+ tokens = list(text)
210
+
211
+ if self.add_blank:
212
+ interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
213
+ interspersed[1::2] = tokens
214
+ tokens = interspersed
215
+
216
+ return tokens
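To make the interspersing concrete: with `add_blank=True` a 3-character string becomes 7 tokens, blanks and characters alternating. A quick sketch, using `"<pad>"` as a stand-in for `self._convert_id_to_token(0)` (id 0 is the pad token in the MMS-TTS vocabularies; treat that as an assumption for other checkpoints):

```python
tokens = list("cat")                            # ['c', 'a', 't']
blank = "<pad>"                                 # stand-in for self._convert_id_to_token(0)
interspersed = [blank] * (len(tokens) * 2 + 1)  # 7 slots
interspersed[1::2] = tokens                     # characters fill the odd positions
print(interspersed)  # ['<pad>', 'c', '<pad>', 'a', '<pad>', 't', '<pad>']
```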
217
+
218
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
219
+ if self.add_blank and len(tokens) > 1:
220
+ tokens = tokens[1::2]
221
+ return "".join(tokens)
222
+
223
+ def _convert_token_to_id(self, token):
224
+ """Converts a token (str) in an id using the vocab."""
225
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
226
+
227
+ def _convert_id_to_token(self, index):
228
+ """Converts an index (integer) in a token (str) using the vocab."""
229
+ return self.decoder.get(index)
230
+
231
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
232
+ if not os.path.isdir(save_directory):
233
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
234
+ return
235
+
236
+ vocab_file = os.path.join(
237
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
238
+ )
239
+
240
+ with open(vocab_file, "w", encoding="utf-8") as f:
241
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
242
+
243
+ return (vocab_file,)
244
+
245
+
246
+ __all__ = ["VitsTokenizer"]
docs/transformers/build/lib/transformers/models/vivit/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_vivit import *
22
+ from .image_processing_vivit import *
23
+ from .modeling_vivit import *
24
+ else:
25
+ import sys
26
+
27
+ _file = globals()["__file__"]
28
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py ADDED
@@ -0,0 +1,119 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ViViT model configuration"""
16
+
17
+ from ...configuration_utils import PretrainedConfig
18
+ from ...utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class VivitConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`VivitModel`]. It is used to instantiate a ViViT
27
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
28
+ defaults will yield a similar configuration to that of the ViViT
29
+ [google/vivit-b-16x2-kinetics400](https://huggingface.co/google/vivit-b-16x2-kinetics400) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+ Args:
35
+ image_size (`int`, *optional*, defaults to 224):
36
+ The size (resolution) of each image.
37
+ num_frames (`int`, *optional*, defaults to 32):
38
+ The number of frames in each video.
39
+ tubelet_size (`List[int]`, *optional*, defaults to `[2, 16, 16]`):
40
+ The size (resolution) of each tubelet.
41
+ num_channels (`int`, *optional*, defaults to 3):
42
+ The number of input channels.
43
+ hidden_size (`int`, *optional*, defaults to 768):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ num_hidden_layers (`int`, *optional*, defaults to 12):
46
+ Number of hidden layers in the Transformer encoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 12):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ intermediate_size (`int`, *optional*, defaults to 3072):
50
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
51
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_fast"`):
52
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
53
+ `"relu"`, `"selu"`, `"gelu_fast"` and `"gelu_new"` are supported.
54
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
55
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
56
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
57
+ The dropout ratio for the attention probabilities.
58
+ initializer_range (`float`, *optional*, defaults to 0.02):
59
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
60
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
61
+ The epsilon used by the layer normalization layers.
62
+ qkv_bias (`bool`, *optional*, defaults to `True`):
63
+ Whether to add a bias to the queries, keys and values.
64
+
65
+ Example:
66
+
67
+ ```python
68
+ >>> from transformers import VivitConfig, VivitModel
69
+
70
+ >>> # Initializing a ViViT google/vivit-b-16x2-kinetics400 style configuration
71
+ >>> configuration = VivitConfig()
72
+
73
+ >>> # Initializing a model (with random weights) from the google/vivit-b-16x2-kinetics400 style configuration
74
+ >>> model = VivitModel(configuration)
75
+
76
+ >>> # Accessing the model configuration
77
+ >>> configuration = model.config
78
+ ```"""
79
+
80
+ model_type = "vivit"
81
+
82
+ def __init__(
83
+ self,
84
+ image_size=224,
85
+ num_frames=32,
86
+ tubelet_size=[2, 16, 16],
87
+ num_channels=3,
88
+ hidden_size=768,
89
+ num_hidden_layers=12,
90
+ num_attention_heads=12,
91
+ intermediate_size=3072,
92
+ hidden_act="gelu_fast",
93
+ hidden_dropout_prob=0.0,
94
+ attention_probs_dropout_prob=0.0,
95
+ initializer_range=0.02,
96
+ layer_norm_eps=1e-06,
97
+ qkv_bias=True,
98
+ **kwargs,
99
+ ):
100
+ self.hidden_size = hidden_size
101
+ self.num_hidden_layers = num_hidden_layers
102
+ self.num_attention_heads = num_attention_heads
103
+ self.intermediate_size = intermediate_size
104
+ self.hidden_act = hidden_act
105
+ self.hidden_dropout_prob = hidden_dropout_prob
106
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
107
+ self.initializer_range = initializer_range
108
+ self.layer_norm_eps = layer_norm_eps
109
+
110
+ self.image_size = image_size
111
+ self.num_frames = num_frames
112
+ self.tubelet_size = tubelet_size
113
+ self.num_channels = num_channels
114
+ self.qkv_bias = qkv_bias
115
+
116
+ super().__init__(**kwargs)
117
+
118
+
119
+ __all__ = ["VivitConfig"]
docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py ADDED
@@ -0,0 +1,231 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Flax ViViT checkpoints from the original repository to PyTorch. URL:
16
+ https://github.com/google-research/scenic/tree/main/scenic/projects/vivit
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import os.path
22
+ from collections import OrderedDict
23
+
24
+ import numpy as np
25
+ import requests
26
+ import torch
27
+ from flax.training.checkpoints import restore_checkpoint
28
+ from huggingface_hub import hf_hub_download
29
+
30
+ from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor
31
+ from transformers.image_utils import PILImageResampling
32
+
33
+
34
+ def download_checkpoint(path):
35
+ url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint"
36
+
37
+ with open(path, "wb") as f:
38
+ with requests.get(url, stream=True) as req:
39
+ for chunk in req.iter_content(chunk_size=2048):
40
+ f.write(chunk)
41
+
42
+
43
+ def get_vivit_config() -> VivitConfig:
44
+ config = VivitConfig()
45
+
46
+ config.num_labels = 400
47
+ repo_id = "huggingface/label-files"
48
+ filename = "kinetics400-id2label.json"
49
+
50
+ id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
51
+ id2label = {int(k): v for k, v in id2label.items()}
52
+ config.id2label = id2label
53
+ config.label2id = {v: k for k, v in id2label.items()}
54
+ return config
55
+
56
+
57
+ # We will verify our results on a video of eating spaghetti
58
+ # Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117,
59
+ # 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174]
60
+ def prepare_video():
61
+ file = hf_hub_download(
62
+ repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset"
63
+ )
64
+ video = np.load(file)
65
+ return list(video)
66
+
67
+
68
+ def transform_attention(current: np.ndarray):
69
+ if np.ndim(current) == 2:
70
+ return transform_attention_bias(current)
71
+
72
+ elif np.ndim(current) == 3:
73
+ return transform_attention_kernel(current)
74
+
75
+ else:
76
+ raise ValueError(f"Invalid number of dimensions: {np.ndim(current)}")
77
+
78
+
79
+ def transform_attention_bias(current: np.ndarray):
80
+ return current.flatten()
81
+
82
+
83
+ def transform_attention_kernel(current: np.ndarray):
84
+ return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T
85
+
86
+
87
+ def transform_attention_output_weight(current: np.ndarray):
88
+ return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T
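The two kernel transforms above flatten Flax's per-head attention weights into the 2-D `(out_features, in_features)` matrices that PyTorch's `nn.Linear` expects. A shape-only sketch, assuming ViT-Base sizes (hidden 768, 12 heads of dimension 64) and the two functions above in scope:

```python
import numpy as np

hidden, num_heads, head_dim = 768, 12, 64

qkv_kernel = np.zeros((hidden, num_heads, head_dim))  # Flax q/k/v layout
out_kernel = np.zeros((num_heads, head_dim, hidden))  # Flax output-projection layout

print(transform_attention_kernel(qkv_kernel).shape)         # (768, 768), i.e. (out, in)
print(transform_attention_output_weight(out_kernel).shape)  # (768, 768), i.e. (out, in)
```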
89
+
90
+
91
+ def transform_state_encoder_block(state_dict, i):
92
+ state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"]
93
+
94
+ prefix = f"encoder.layer.{i}."
95
+ new_state = {
96
+ prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"],
97
+ prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]),
98
+ prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"],
99
+ prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]),
100
+ prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"],
101
+ prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"],
102
+ prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"],
103
+ prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"],
104
+ prefix + "attention.attention.query.bias": transform_attention(
105
+ state["MultiHeadDotProductAttention_0"]["query"]["bias"]
106
+ ),
107
+ prefix + "attention.attention.query.weight": transform_attention(
108
+ state["MultiHeadDotProductAttention_0"]["query"]["kernel"]
109
+ ),
110
+ prefix + "attention.attention.key.bias": transform_attention(
111
+ state["MultiHeadDotProductAttention_0"]["key"]["bias"]
112
+ ),
113
+ prefix + "attention.attention.key.weight": transform_attention(
114
+ state["MultiHeadDotProductAttention_0"]["key"]["kernel"]
115
+ ),
116
+ prefix + "attention.attention.value.bias": transform_attention(
117
+ state["MultiHeadDotProductAttention_0"]["value"]["bias"]
118
+ ),
119
+ prefix + "attention.attention.value.weight": transform_attention(
120
+ state["MultiHeadDotProductAttention_0"]["value"]["kernel"]
121
+ ),
122
+ prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"],
123
+ prefix + "attention.output.dense.weight": transform_attention_output_weight(
124
+ state["MultiHeadDotProductAttention_0"]["out"]["kernel"]
125
+ ),
126
+ }
127
+
128
+ return new_state
129
+
130
+
131
+ def get_n_layers(state_dict):
132
+ return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"].keys()])
133
+
134
+
135
+ def transform_state(state_dict, classification_head=False):
136
+ transformer_layers = get_n_layers(state_dict)
137
+
138
+ new_state = OrderedDict()
139
+
140
+ new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"]
141
+ new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"]
142
+
143
+ new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose(
144
+ state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2)
145
+ )
146
+ new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"]
147
+
148
+ new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"]
149
+ new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][
150
+ "pos_embedding"
151
+ ]
152
+
153
+ for i in range(transformer_layers):
154
+ new_state.update(transform_state_encoder_block(state_dict, i))
155
+
156
+ if classification_head:
157
+ new_state = {"vivit." + k: v for k, v in new_state.items()}
158
+ new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"])
159
+ new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"])
160
+
161
+ return {k: torch.tensor(v) for k, v in new_state.items()}
162
+
163
+
164
+ # checks that image processor settings are the same as in the original implementation
165
+ # original: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py
166
+ # dataset specific config:
167
+ # https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py
168
+ def get_processor() -> VivitImageProcessor:
169
+ extractor = VivitImageProcessor()
170
+
171
+ assert extractor.do_resize is True
172
+ assert extractor.size == {"shortest_edge": 256}
173
+ assert extractor.do_center_crop is True
174
+ assert extractor.crop_size == {"width": 224, "height": 224}
175
+ assert extractor.resample == PILImageResampling.BILINEAR
176
+
177
+ # here: https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py
178
+ # one can seen that add_image has default values for normalization_mean and normalization_std set to 0 and 1
179
+ # which effectively means no normalization (and ViViT does not overwrite those when calling this func)
180
+ assert extractor.do_normalize is False
181
+ assert extractor.do_rescale is True
182
+ assert extractor.rescale_factor == 1 / 255
183
+
184
+ # zero-centering = True in original implementation
185
+ assert extractor.do_zero_centering is True
186
+
187
+ return extractor
188
+
189
+
190
+ def convert(output_path: str):
191
+ flax_model_path = "checkpoint"
192
+
193
+ if not os.path.exists(flax_model_path):
194
+ download_checkpoint(flax_model_path)
195
+
196
+ state_dict = restore_checkpoint(flax_model_path, None)
197
+ new_state = transform_state(state_dict, classification_head=True)
198
+
199
+ config = get_vivit_config()
200
+
201
+ assert config.image_size == 224
202
+ assert config.num_frames == 32
203
+
204
+ model = VivitForVideoClassification(config)
205
+ model.load_state_dict(new_state)
206
+ model.eval()
207
+
208
+ extractor = get_processor()
209
+
210
+ video = prepare_video()
211
+ inputs = extractor(video, return_tensors="pt")
212
+
213
+ outputs = model(**inputs)
214
+
215
+ expected_shape = torch.Size([1, 400])
216
+ expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658])
217
+
218
+ assert outputs.logits.shape == expected_shape
219
+ assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5]
220
+
221
+ model.save_pretrained(output_path)
222
+ extractor.save_pretrained(output_path)
223
+
224
+
225
+ if __name__ == "__main__":
226
+ parser = argparse.ArgumentParser()
227
+
228
+ parser.add_argument("--output_model_name", "-o", type=str, help="Output path for the converted HuggingFace model")
229
+
230
+ args = parser.parse_args()
231
+ convert(args.output_model_name)
docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py ADDED
@@ -0,0 +1,407 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Vivit."""
16
+
17
+ from typing import Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+
21
+ from transformers.utils import is_vision_available
22
+ from transformers.utils.generic import TensorType
23
+
24
+ from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
25
+ from ...image_transforms import (
26
+ get_resize_output_image_size,
27
+ rescale,
28
+ resize,
29
+ to_channel_dimension_format,
30
+ )
31
+ from ...image_utils import (
32
+ IMAGENET_STANDARD_MEAN,
33
+ IMAGENET_STANDARD_STD,
34
+ ChannelDimension,
35
+ ImageInput,
36
+ PILImageResampling,
37
+ infer_channel_dimension_format,
38
+ is_scaled_image,
39
+ is_valid_image,
40
+ to_numpy_array,
41
+ valid_images,
42
+ validate_preprocess_arguments,
43
+ )
44
+ from ...utils import filter_out_non_signature_kwargs, logging
45
+
46
+
47
+ if is_vision_available():
48
+ import PIL
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ def make_batched(videos) -> List[List[ImageInput]]:
54
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
55
+ return videos
56
+
57
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
58
+ return [videos]
59
+
60
+ elif is_valid_image(videos):
61
+ return [[videos]]
62
+
63
+ raise ValueError(f"Could not make batched video from {videos}")
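`make_batched` normalizes the three accepted layouts to `List[List[frame]]`. A quick usage sketch, assuming the function above is in scope:

```python
import numpy as np

frame = np.zeros((224, 224, 3), dtype=np.uint8)

make_batched(frame)               # single frame        -> [[frame]]
make_batched([frame, frame])      # one video of frames -> [[frame, frame]]
make_batched([[frame], [frame]])  # batch of videos     -> returned unchanged
```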
64
+
65
+
66
+ class VivitImageProcessor(BaseImageProcessor):
67
+ r"""
68
+ Constructs a Vivit image processor.
69
+
70
+ Args:
71
+ do_resize (`bool`, *optional*, defaults to `True`):
72
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
73
+ `do_resize` parameter in the `preprocess` method.
74
+ size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 256}`):
75
+ Size of the output image after resizing. The shortest edge of the image will be resized to
76
+ `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overridden by
77
+ `size` in the `preprocess` method.
78
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
79
+ Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
80
+ `preprocess` method.
81
+ do_center_crop (`bool`, *optional*, defaults to `True`):
82
+ Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
83
+ parameter in the `preprocess` method.
84
+ crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
85
+ Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
86
+ `preprocess` method.
87
+ do_rescale (`bool`, *optional*, defaults to `True`):
88
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
89
+ parameter in the `preprocess` method.
90
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/127.5`):
91
+ Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
92
+ in the `preprocess` method.
93
+ offset (`bool`, *optional*, defaults to `True`):
94
+ Whether to scale the image in both negative and positive directions. Can be overridden by the `offset` parameter in
95
+ the `preprocess` method.
96
+ do_normalize (`bool`, *optional*, defaults to `True`):
97
+ Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
98
+ method.
99
+ image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
100
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
101
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
102
+ image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
103
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
104
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
105
+ """
106
+
107
+ model_input_names = ["pixel_values"]
108
+
109
+ def __init__(
110
+ self,
111
+ do_resize: bool = True,
112
+ size: Dict[str, int] = None,
113
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
114
+ do_center_crop: bool = True,
115
+ crop_size: Dict[str, int] = None,
116
+ do_rescale: bool = True,
117
+ rescale_factor: Union[int, float] = 1 / 127.5,
118
+ offset: bool = True,
119
+ do_normalize: bool = True,
120
+ image_mean: Optional[Union[float, List[float]]] = None,
121
+ image_std: Optional[Union[float, List[float]]] = None,
122
+ **kwargs,
123
+ ) -> None:
124
+ super().__init__(**kwargs)
125
+ size = size if size is not None else {"shortest_edge": 256}
126
+ size = get_size_dict(size, default_to_square=False)
127
+ crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
128
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
129
+
130
+ self.do_resize = do_resize
131
+ self.size = size
132
+ self.do_center_crop = do_center_crop
133
+ self.crop_size = crop_size
134
+ self.resample = resample
135
+ self.do_rescale = do_rescale
136
+ self.rescale_factor = rescale_factor
137
+ self.offset = offset
138
+ self.do_normalize = do_normalize
139
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
140
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
141
+
142
+ def resize(
143
+ self,
144
+ image: np.ndarray,
145
+ size: Dict[str, int],
146
+ resample: PILImageResampling = PILImageResampling.BILINEAR,
147
+ data_format: Optional[Union[str, ChannelDimension]] = None,
148
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
149
+ **kwargs,
150
+ ) -> np.ndarray:
151
+ """
152
+ Resize an image.
153
+
154
+ Args:
155
+ image (`np.ndarray`):
156
+ Image to resize.
157
+ size (`Dict[str, int]`):
158
+ Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
159
+ have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
160
+ shortest edge of length `s` while keeping the aspect ratio of the original image.
161
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
162
+ Resampling filter to use when resizing the image.
163
+ data_format (`str` or `ChannelDimension`, *optional*):
164
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
165
+ input_data_format (`str` or `ChannelDimension`, *optional*):
166
+ The channel dimension format of the input image. If not provided, it will be inferred.
167
+ """
168
+ size = get_size_dict(size, default_to_square=False)
169
+ if "shortest_edge" in size:
170
+ output_size = get_resize_output_image_size(
171
+ image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format
172
+ )
173
+ elif "height" in size and "width" in size:
174
+ output_size = (size["height"], size["width"])
175
+ else:
176
+ raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}")
177
+ return resize(
178
+ image,
179
+ size=output_size,
180
+ resample=resample,
181
+ data_format=data_format,
182
+ input_data_format=input_data_format,
183
+ **kwargs,
184
+ )
185
+
186
+ # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale
187
+ def rescale(
188
+ self,
189
+ image: np.ndarray,
190
+ scale: Union[int, float],
191
+ offset: bool = True,
192
+ data_format: Optional[Union[str, ChannelDimension]] = None,
193
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
194
+ **kwargs,
195
+ ):
196
+ """
197
+ Rescale an image by a scale factor.
198
+
199
+ If `offset` is `True`, the image has its values rescaled by `scale` and then shifted down by 1. If `scale` is
200
+ 1/127.5, the image is rescaled to [-1, 1].
201
+ image = image * scale - 1
202
+
203
+ If `offset` is `False` and `scale` is 1/255, the image is rescaled to [0, 1].
204
+ image = image * scale
205
+
206
+ Args:
207
+ image (`np.ndarray`):
208
+ Image to rescale.
209
+ scale (`int` or `float`):
210
+ Scale to apply to the image.
211
+ offset (`bool`, *optional*):
212
+ Whether to scale the image in both negative and positive directions.
213
+ data_format (`str` or `ChannelDimension`, *optional*):
214
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
215
+ input_data_format (`ChannelDimension` or `str`, *optional*):
216
+ The channel dimension format of the input image. If not provided, it will be inferred.
217
+ """
218
+ rescaled_image = rescale(
219
+ image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs
220
+ )
221
+
222
+ if offset:
223
+ rescaled_image = rescaled_image - 1
224
+
225
+ return rescaled_image
226
+
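+ # Quick arithmetic check of the two modes documented above (illustrative only):
+ # with scale=1/127.5 and offset=True, a uint8 pixel of 255 maps to
+ # 255/127.5 - 1 = 1.0 and 0 maps to -1.0; with scale=1/255 and offset=False,
+ # the same pixels map to 1.0 and 0.0.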
227
+ def _preprocess_image(
228
+ self,
229
+ image: ImageInput,
230
+ do_resize: Optional[bool] = None,
231
+ size: Dict[str, int] = None,
232
+ resample: PILImageResampling = None,
233
+ do_center_crop: Optional[bool] = None,
234
+ crop_size: Dict[str, int] = None,
235
+ do_rescale: Optional[bool] = None,
236
+ rescale_factor: Optional[float] = None,
237
+ offset: Optional[bool] = None,
238
+ do_normalize: Optional[bool] = None,
239
+ image_mean: Optional[Union[float, List[float]]] = None,
240
+ image_std: Optional[Union[float, List[float]]] = None,
241
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
242
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
243
+ ) -> np.ndarray:
244
+ """Preprocesses a single image."""
245
+
246
+ validate_preprocess_arguments(
247
+ do_rescale=do_rescale,
248
+ rescale_factor=rescale_factor,
249
+ do_normalize=do_normalize,
250
+ image_mean=image_mean,
251
+ image_std=image_std,
252
+ do_center_crop=do_center_crop,
253
+ crop_size=crop_size,
254
+ do_resize=do_resize,
255
+ size=size,
256
+ resample=resample,
257
+ )
258
+
259
+ if offset and not do_rescale:
260
+ raise ValueError("For `offset` to be applied, `do_rescale` must also be set to `True`.")
261
+
262
+ # All transformations expect numpy arrays.
263
+ image = to_numpy_array(image)
264
+
265
+ if do_rescale and is_scaled_image(image):
266
+ logger.warning_once(
267
+ "It looks like you are trying to rescale already rescaled images. If the input"
268
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
269
+ )
270
+
271
+ if input_data_format is None:
272
+ input_data_format = infer_channel_dimension_format(image)
273
+
274
+ if do_resize:
275
+ image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
276
+
277
+ if do_center_crop:
278
+ image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
279
+
280
+ if do_rescale:
281
+ image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)
282
+
283
+ if do_normalize:
284
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
285
+
286
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
287
+ return image
288
+
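+ # Order of operations applied above: resize -> center crop -> rescale (with the
+ # optional offset to [-1, 1]) -> normalize, followed by a final channel-format
+ # conversion to `data_format`.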
289
+ @filter_out_non_signature_kwargs()
290
+ def preprocess(
291
+ self,
292
+ videos: ImageInput,
293
+ do_resize: Optional[bool] = None,
294
+ size: Dict[str, int] = None,
295
+ resample: PILImageResampling = None,
296
+ do_center_crop: Optional[bool] = None,
297
+ crop_size: Dict[str, int] = None,
298
+ do_rescale: Optional[bool] = None,
299
+ rescale_factor: Optional[float] = None,
300
+ offset: Optional[bool] = None,
301
+ do_normalize: Optional[bool] = None,
302
+ image_mean: Optional[Union[float, List[float]]] = None,
303
+ image_std: Optional[Union[float, List[float]]] = None,
304
+ return_tensors: Optional[Union[str, TensorType]] = None,
305
+ data_format: ChannelDimension = ChannelDimension.FIRST,
306
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
307
+ ) -> BatchFeature:
308
+ """
309
+ Preprocess a video or batch of videos.
310
+
311
+ Args:
312
+ videos (`ImageInput`):
313
+ Video frames to preprocess. Expects a single or batch of video frames with pixel values ranging from 0
314
+ to 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`.
315
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
316
+ Whether to resize the image.
317
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
318
+ Size of the image after applying resize.
319
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
320
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only
321
+ has an effect if `do_resize` is set to `True`.
322
+ do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
323
+ Whether to center crop the image.
324
+ crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
325
+ Size of the image after applying the center crop.
326
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
327
+ Whether to rescale the image values between `[-1, 1]` if `offset` is `True`, `[0, 1]` otherwise.
328
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
329
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
330
+ offset (`bool`, *optional*, defaults to `self.offset`):
331
+ Whether to offset the rescaled image so its values span both negative and positive ranges (i.e. subtract 1 after scaling).
332
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
333
+ Whether to normalize the image.
334
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
335
+ Image mean.
336
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
337
+ Image standard deviation.
338
+ return_tensors (`str` or `TensorType`, *optional*):
339
+ The type of tensors to return. Can be one of:
340
+ - Unset: Return a list of `np.ndarray`.
341
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
342
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
343
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
344
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
345
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
346
+ The channel dimension format for the output image. Can be one of:
347
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
348
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
349
+ - Unset: Use the inferred channel dimension format of the input image.
350
+ input_data_format (`ChannelDimension` or `str`, *optional*):
351
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
352
+ from the input image. Can be one of:
353
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
354
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
355
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
356
+ """
357
+ do_resize = do_resize if do_resize is not None else self.do_resize
358
+ resample = resample if resample is not None else self.resample
359
+ do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
360
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
361
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
362
+ offset = offset if offset is not None else self.offset
363
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
364
+ image_mean = image_mean if image_mean is not None else self.image_mean
365
+ image_std = image_std if image_std is not None else self.image_std
366
+
367
+ size = size if size is not None else self.size
368
+ size = get_size_dict(size, default_to_square=False)
369
+ crop_size = crop_size if crop_size is not None else self.crop_size
370
+ crop_size = get_size_dict(crop_size, param_name="crop_size")
371
+
372
+ if not valid_images(videos):
373
+ raise ValueError(
374
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
375
+ "torch.Tensor, tf.Tensor or jax.ndarray."
376
+ )
377
+
378
+ videos = make_batched(videos)
379
+
380
+ videos = [
381
+ [
382
+ self._preprocess_image(
383
+ image=img,
384
+ do_resize=do_resize,
385
+ size=size,
386
+ resample=resample,
387
+ do_center_crop=do_center_crop,
388
+ crop_size=crop_size,
389
+ do_rescale=do_rescale,
390
+ rescale_factor=rescale_factor,
391
+ offset=offset,
392
+ do_normalize=do_normalize,
393
+ image_mean=image_mean,
394
+ image_std=image_std,
395
+ data_format=data_format,
396
+ input_data_format=input_data_format,
397
+ )
398
+ for img in video
399
+ ]
400
+ for video in videos
401
+ ]
402
+
403
+ data = {"pixel_values": videos}
404
+ return BatchFeature(data=data, tensor_type=return_tensors)
405
+
406
+
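+ # Minimal usage sketch (assumes the "google/vivit-b-16x2-kinetics400" checkpoint
+ # referenced in the model docs and a `video` given as a list of HxWx3 uint8
+ # numpy frames):
+ #
+ # processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
+ # inputs = processor(video, return_tensors="pt")
+ # inputs["pixel_values"].shape  # (1, num_frames, 3, 224, 224) with the default crop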
407
+ __all__ = ["VivitImageProcessor"]
docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py ADDED
@@ -0,0 +1,844 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Google AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch ViViT model."""
16
+
17
+ from typing import Callable, Optional, Set, Tuple, Union
18
+
19
+ import torch
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss, MSELoss
23
+
24
+ from ...activations import ACT2FN
25
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
26
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
+ from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
28
+ from ...utils import (
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ logging,
32
+ replace_return_docstrings,
33
+ torch_int,
34
+ )
35
+ from .configuration_vivit import VivitConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CHECKPOINT_FOR_DOC = "google/vivit-b-16x2-kinetics400"
41
+ _CONFIG_FOR_DOC = "VivitConfig"
42
+
43
+
44
+ class VivitTubeletEmbeddings(nn.Module):
45
+ """
46
+ Construct Vivit Tubelet embeddings.
47
+
48
+ This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
49
+ shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
50
+
51
+ The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
52
+ (width // tubelet_size[2]).
53
+ """
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.num_frames = config.num_frames
58
+ self.image_size = config.image_size
59
+ self.patch_size = config.tubelet_size
60
+ self.num_patches = (
61
+ (self.image_size // self.patch_size[2])
62
+ * (self.image_size // self.patch_size[1])
63
+ * (self.num_frames // self.patch_size[0])
64
+ )
65
+ self.embed_dim = config.hidden_size
66
+
67
+ self.projection = nn.Conv3d(
68
+ config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
69
+ )
70
+
71
+ def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
72
+ batch_size, num_frames, num_channels, height, width = pixel_values.shape
73
+ if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
74
+ raise ValueError(
75
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
76
+ )
77
+
78
+ # permute to (batch_size, num_channels, num_frames, height, width)
79
+ pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
80
+
81
+ x = self.projection(pixel_values)
82
+ # out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
83
+ # flattens time and space dimensions, transposes to (out_batch_size, flat_tokens, out_num_channels)
84
+ x = x.flatten(2).transpose(1, 2)
85
+ return x
86
+
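+ # Worked example (illustrative, assuming the default ViViT-B setup of 32 frames,
+ # 224x224 images and tubelet_size (2, 16, 16)): num_patches is
+ # (32 // 2) * (224 // 16) * (224 // 16) = 16 * 14 * 14 = 3136 tokens, which becomes
+ # 3137 once the CLS token is prepended (cf. the [1, 3137, 768] shape in the
+ # VivitModel docstring example below).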
87
+
88
+ class VivitEmbeddings(nn.Module):
89
+ """
90
+ Vivit Embeddings.
91
+
92
+ Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
93
+ """
94
+
95
+ def __init__(self, config):
96
+ super().__init__()
97
+
98
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
99
+ self.patch_embeddings = VivitTubeletEmbeddings(config)
100
+
101
+ self.position_embeddings = nn.Parameter(
102
+ torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
103
+ )
104
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
105
+ self.patch_size = config.tubelet_size[1:]
106
+ self.config = config
107
+
108
+ # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
109
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
110
+ """
111
+ This method allows interpolating the pre-trained position encodings so that the model can be used on higher-resolution
112
+ images. This method is also adapted to support torch.jit tracing.
113
+
114
+ Adapted from:
115
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
116
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
117
+ """
118
+
119
+ num_patches = embeddings.shape[1] - 1
120
+ num_positions = self.position_embeddings.shape[1] - 1
121
+
122
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
123
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
124
+ return self.position_embeddings
125
+
126
+ class_pos_embed = self.position_embeddings[:, :1]
127
+ patch_pos_embed = self.position_embeddings[:, 1:]
128
+
129
+ dim = embeddings.shape[-1]
130
+
131
+ new_height = height // self.patch_size[0]
132
+ new_width = width // self.patch_size[1]
133
+
134
+ sqrt_num_positions = torch_int(num_positions**0.5)
135
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
136
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
137
+
138
+ patch_pos_embed = nn.functional.interpolate(
139
+ patch_pos_embed,
140
+ size=(new_height, new_width),
141
+ mode="bicubic",
142
+ align_corners=False,
143
+ )
144
+
145
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
146
+
147
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
148
+
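+ # In brief, the method above treats the stored patch position embeddings as a
+ # square grid of side sqrt(num_positions), bicubically resizes that grid to the
+ # (height // patch_size[0], width // patch_size[1]) grid implied by the new
+ # resolution, and re-concatenates the class-token embedding in front.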
149
+ def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
150
+ batch_size, num_frames, num_channels, height, width = pixel_values.shape
151
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
152
+
153
+ cls_tokens = self.cls_token.tile([batch_size, 1, 1])
154
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
155
+
156
+ # add positional encoding to each token
157
+ if interpolate_pos_encoding:
158
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
159
+ else:
160
+ embeddings = embeddings + self.position_embeddings
161
+
162
+ embeddings = self.dropout(embeddings)
163
+
164
+ return embeddings
165
+
166
+
167
+ # Copied from transformers.models.vit.modeling_vit.eager_attention_forward
168
+ def eager_attention_forward(
169
+ module: nn.Module,
170
+ query: torch.Tensor,
171
+ key: torch.Tensor,
172
+ value: torch.Tensor,
173
+ attention_mask: Optional[torch.Tensor],
174
+ scaling: float,
175
+ dropout: float = 0.0,
176
+ **kwargs,
177
+ ):
178
+ # Take the dot product between "query" and "key" to get the raw attention scores.
179
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
180
+
181
+ # Normalize the attention scores to probabilities.
182
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
183
+
184
+ # This is actually dropping out entire tokens to attend to, which might
185
+ # seem a bit unusual, but is taken from the original Transformer paper.
186
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
187
+
188
+ # Mask heads if we want to
189
+ if attention_mask is not None:
190
+ attn_weights = attn_weights * attention_mask
191
+
192
+ attn_output = torch.matmul(attn_weights, value)
193
+ attn_output = attn_output.transpose(1, 2).contiguous()
194
+
195
+ return attn_output, attn_weights
196
+
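+ # Shape note: query/key/value arrive as (batch, num_heads, seq_len, head_dim);
+ # attn_weights is (batch, num_heads, seq_len, seq_len), and attn_output is
+ # transposed back to (batch, seq_len, num_heads, head_dim) so the caller can
+ # flatten the last two dimensions into all_head_size.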
197
+
198
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit
199
+ class VivitSelfAttention(nn.Module):
200
+ def __init__(self, config: VivitConfig) -> None:
201
+ super().__init__()
202
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
203
+ raise ValueError(
204
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
205
+ f"heads {config.num_attention_heads}."
206
+ )
207
+
208
+ self.config = config
209
+ self.num_attention_heads = config.num_attention_heads
210
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
211
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
212
+ self.dropout_prob = config.attention_probs_dropout_prob
213
+ self.scaling = self.attention_head_size**-0.5
214
+ self.is_causal = False
215
+
216
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
217
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
218
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
219
+
220
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
221
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
222
+ x = x.view(new_x_shape)
223
+ return x.permute(0, 2, 1, 3)
224
+
225
+ def forward(
226
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
227
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
228
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
229
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
230
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
231
+
232
+ attention_interface: Callable = eager_attention_forward
233
+ if self.config._attn_implementation != "eager":
234
+ if self.config._attn_implementation == "sdpa" and output_attentions:
235
+ logger.warning_once(
236
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
237
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
238
+ )
239
+ else:
240
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
241
+
242
+ context_layer, attention_probs = attention_interface(
243
+ self,
244
+ query_layer,
245
+ key_layer,
246
+ value_layer,
247
+ head_mask,
248
+ is_causal=self.is_causal,
249
+ scaling=self.scaling,
250
+ dropout=0.0 if not self.training else self.dropout_prob,
251
+ )
252
+
253
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
254
+ context_layer = context_layer.reshape(new_context_layer_shape)
255
+
256
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
257
+
258
+ return outputs
259
+
260
+
261
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
262
+ class VivitSelfOutput(nn.Module):
263
+ """
264
+ The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the
265
+ layernorm applied before each block.
266
+ """
267
+
268
+ def __init__(self, config: VivitConfig) -> None:
269
+ super().__init__()
270
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
271
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
272
+
273
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
274
+ hidden_states = self.dense(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+
277
+ return hidden_states
278
+
279
+
280
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Vivit
281
+ class VivitAttention(nn.Module):
282
+ def __init__(self, config: VivitConfig) -> None:
283
+ super().__init__()
284
+ self.attention = VivitSelfAttention(config)
285
+ self.output = VivitSelfOutput(config)
286
+ self.pruned_heads = set()
287
+
288
+ def prune_heads(self, heads: Set[int]) -> None:
289
+ if len(heads) == 0:
290
+ return
291
+ heads, index = find_pruneable_heads_and_indices(
292
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
293
+ )
294
+
295
+ # Prune linear layers
296
+ self.attention.query = prune_linear_layer(self.attention.query, index)
297
+ self.attention.key = prune_linear_layer(self.attention.key, index)
298
+ self.attention.value = prune_linear_layer(self.attention.value, index)
299
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
300
+
301
+ # Update hyper params and store pruned heads
302
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
303
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
304
+ self.pruned_heads = self.pruned_heads.union(heads)
305
+
306
+ def forward(
307
+ self,
308
+ hidden_states: torch.Tensor,
309
+ head_mask: Optional[torch.Tensor] = None,
310
+ output_attentions: bool = False,
311
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
312
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
313
+
314
+ attention_output = self.output(self_outputs[0], hidden_states)
315
+
316
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
317
+ return outputs
318
+
319
+
320
+ class VivitIntermediate(nn.Module):
321
+ def __init__(self, config):
322
+ super().__init__()
323
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
324
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
325
+ if isinstance(config.hidden_act, str):
326
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
327
+ else:
328
+ self.intermediate_act_fn = config.hidden_act
329
+
330
+ def forward(self, hidden_states):
331
+ hidden_states = self.dense(hidden_states)
332
+ hidden_states = self.intermediate_act_fn(hidden_states)
333
+ hidden_states = self.dropout(hidden_states)
334
+
335
+ return hidden_states
336
+
337
+
338
+ class VivitOutput(nn.Module):
339
+ def __init__(self, config):
340
+ super().__init__()
341
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
342
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
343
+
344
+ def forward(self, hidden_states, input_tensor):
345
+ hidden_states = self.dense(hidden_states)
346
+
347
+ hidden_states = self.dropout(hidden_states)
348
+
349
+ hidden_states = hidden_states + input_tensor
350
+
351
+ return hidden_states
352
+
353
+
354
+ class VivitLayer(nn.Module):
355
+ """This corresponds to the EncoderBlock class in the scenic/vivit implementation."""
356
+
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
360
+ self.seq_len_dim = 1
361
+ self.attention = VivitAttention(config)
362
+ self.intermediate = VivitIntermediate(config)
363
+ self.output = VivitOutput(config)
364
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
365
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
366
+
367
+ def forward(self, hidden_states, head_mask=None, output_attentions=False):
368
+ self_attention_outputs = self.attention(
369
+ # in Vivit, layernorm is applied before self-attention
370
+ self.layernorm_before(hidden_states),
371
+ head_mask,
372
+ output_attentions=output_attentions,
373
+ )
374
+ attention_output = self_attention_outputs[0]
375
+ # add self attentions if we output attention weights
376
+ outputs = self_attention_outputs[1:]
377
+
378
+ # first residual connection
379
+ hidden_states = attention_output + hidden_states
380
+
381
+ # in Vivit, layernorm is also applied after self-attention
382
+ layer_output = self.layernorm_after(hidden_states)
383
+ layer_output = self.intermediate(layer_output)
384
+
385
+ # second residual connection is done here
386
+ layer_output = self.output(layer_output, hidden_states)
387
+
388
+ outputs = (layer_output,) + outputs
389
+
390
+ return outputs
391
+
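+ # Block summary: VivitLayer is a pre-norm transformer block --
+ # layernorm_before -> self-attention -> residual add, then
+ # layernorm_after -> MLP (intermediate + output) -> residual add.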
392
+
393
+ class VivitEncoder(nn.Module):
394
+ def __init__(self, config):
395
+ super().__init__()
396
+ self.config = config
397
+ self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)])
398
+ self.gradient_checkpointing = False
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ head_mask=None,
404
+ output_attentions=False,
405
+ output_hidden_states=False,
406
+ return_dict=True,
407
+ ):
408
+ all_hidden_states = () if output_hidden_states else None
409
+ all_self_attentions = () if output_attentions else None
410
+
411
+ for i, layer_module in enumerate(self.layer):
412
+ if output_hidden_states:
413
+ all_hidden_states = all_hidden_states + (hidden_states,)
414
+
415
+ layer_head_mask = head_mask[i] if head_mask is not None else None
416
+
417
+ if self.gradient_checkpointing and self.training:
418
+ layer_outputs = self._gradient_checkpointing_func(
419
+ layer_module.__call__,
420
+ hidden_states,
421
+ layer_head_mask,
422
+ output_attentions,
423
+ )
424
+ else:
425
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
426
+
427
+ hidden_states = layer_outputs[0]
428
+
429
+ if output_attentions:
430
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
431
+
432
+ if output_hidden_states:
433
+ all_hidden_states = all_hidden_states + (hidden_states,)
434
+
435
+ if not return_dict:
436
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
437
+ return BaseModelOutput(
438
+ last_hidden_state=hidden_states,
439
+ hidden_states=all_hidden_states,
440
+ attentions=all_self_attentions,
441
+ )
442
+
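+ # Note: with gradient checkpointing enabled and the module in training mode,
+ # each VivitLayer forward is re-computed during the backward pass instead of
+ # storing its activations, trading extra compute for lower memory use.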
443
+
444
+ class VivitPooler(nn.Module):
445
+ def __init__(self, config):
446
+ super().__init__()
447
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
448
+ self.activation = nn.Tanh()
449
+
450
+ def forward(self, hidden_states):
451
+ # We "pool" the model by simply taking the hidden state corresponding
452
+ # to the first token.
453
+ first_token_tensor = hidden_states[:, 0]
454
+ pooled_output = self.dense(first_token_tensor)
455
+ pooled_output = self.activation(pooled_output)
456
+ return pooled_output
457
+
458
+
459
+ class VivitPreTrainedModel(PreTrainedModel):
460
+ """
461
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
462
+ models.
463
+ """
464
+
465
+ config_class = VivitConfig
466
+ base_model_prefix = "vivit"
467
+ main_input_name = "pixel_values"
468
+ supports_gradient_checkpointing = True
469
+ _no_split_modules = []
470
+ _supports_sdpa = True
471
+ _supports_flash_attn_2 = True
472
+
473
+ def _init_weights(self, module):
474
+ """Initialize the weights"""
475
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
476
+ # Slightly different from the TF version which uses truncated_normal for initialization
477
+ # cf https://github.com/pytorch/pytorch/pull/5617
478
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
479
+ if module.bias is not None:
480
+ module.bias.data.zero_()
481
+ elif isinstance(module, nn.Embedding):
482
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
483
+ if module.padding_idx is not None:
484
+ module.weight.data[module.padding_idx].zero_()
485
+ elif isinstance(module, nn.LayerNorm):
486
+ module.bias.data.zero_()
487
+ module.weight.data.fill_(1.0)
488
+ elif isinstance(module, VivitEmbeddings):
489
+ module.cls_token.data.zero_()
490
+ module.position_embeddings.data.zero_()
491
+
492
+
493
+ VIVIT_START_DOCSTRING = r"""
494
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
495
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
496
+ behavior.
497
+
498
+ Parameters:
499
+ config ([`VivitConfig`]): Model configuration class with all the parameters of the model.
500
+ Initializing with a config file does not load the weights associated with the model, only the
501
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
502
+ """
503
+
504
+ VIVIT_INPUTS_DOCSTRING = r"""
505
+ Args:
506
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
507
+ Pixel values. Pixel values can be obtained using [`VivitImageProcessor`]. See
508
+ [`VivitImageProcessor.preprocess`] for details.
509
+
510
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
511
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
512
+
513
+ - 1 indicates the head is **not masked**,
514
+ - 0 indicates the head is **masked**.
515
+
516
+ output_attentions (`bool`, *optional*):
517
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
518
+ tensors for more detail.
519
+ output_hidden_states (`bool`, *optional*):
520
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
521
+ more detail.
522
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
523
+ Whether to interpolate the pre-trained position encodings.
524
+ return_dict (`bool`, *optional*):
525
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
526
+ """
527
+
528
+
529
+ @add_start_docstrings(
530
+ "The bare ViViT Transformer model outputting raw hidden-states without any specific head on top.",
531
+ VIVIT_START_DOCSTRING,
532
+ )
533
+ class VivitModel(VivitPreTrainedModel):
534
+ def __init__(self, config, add_pooling_layer=True):
535
+ super().__init__(config)
536
+ self.config = config
537
+
538
+ self.embeddings = VivitEmbeddings(config)
539
+ self.encoder = VivitEncoder(config)
540
+
541
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
542
+ self.pooler = VivitPooler(config) if add_pooling_layer else None
543
+
544
+ # Initialize weights and apply final processing
545
+ self.post_init()
546
+
547
+ def get_input_embeddings(self):
548
+ return self.embeddings.patch_embeddings
549
+
550
+ def _prune_heads(self, heads_to_prune):
551
+ """
552
+ Prunes heads of the model.
553
+
554
+ Args:
555
+ heads_to_prune:
556
+ dict of {layer_num: list of heads to prune in this layer}
557
+ """
558
+ for layer, heads in heads_to_prune.items():
559
+ self.encoder.layer[layer].attention.prune_heads(heads)
560
+
561
+ @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
562
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
563
+ def forward(
564
+ self,
565
+ pixel_values: Optional[torch.FloatTensor] = None,
566
+ head_mask: Optional[torch.FloatTensor] = None,
567
+ output_attentions: Optional[bool] = None,
568
+ output_hidden_states: Optional[bool] = None,
569
+ interpolate_pos_encoding: bool = False,
570
+ return_dict: Optional[bool] = None,
571
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
572
+ r"""
573
+ Returns:
574
+
575
+ Examples:
576
+
577
+ ```python
578
+ >>> import av
579
+ >>> import numpy as np
580
+
581
+ >>> from transformers import VivitImageProcessor, VivitModel
582
+ >>> from huggingface_hub import hf_hub_download
583
+
584
+ >>> np.random.seed(0)
585
+
586
+
587
+ >>> def read_video_pyav(container, indices):
588
+ ... '''
589
+ ... Decode the video with PyAV decoder.
590
+ ... Args:
591
+ ... container (`av.container.input.InputContainer`): PyAV container.
592
+ ... indices (`List[int]`): List of frame indices to decode.
593
+ ... Returns:
594
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
595
+ ... '''
596
+ ... frames = []
597
+ ... container.seek(0)
598
+ ... start_index = indices[0]
599
+ ... end_index = indices[-1]
600
+ ... for i, frame in enumerate(container.decode(video=0)):
601
+ ... if i > end_index:
602
+ ... break
603
+ ... if i >= start_index and i in indices:
604
+ ... frames.append(frame)
605
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
606
+
607
+
608
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
609
+ ... '''
610
+ ... Sample a given number of frame indices from the video.
611
+ ... Args:
612
+ ... clip_len (`int`): Total number of frames to sample.
613
+ ... frame_sample_rate (`int`): Sample every n-th frame.
614
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
615
+ ... Returns:
616
+ ... indices (`List[int]`): List of sampled frame indices
617
+ ... '''
618
+ ... converted_len = int(clip_len * frame_sample_rate)
619
+ ... end_idx = np.random.randint(converted_len, seg_len)
620
+ ... start_idx = end_idx - converted_len
621
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
622
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
623
+ ... return indices
624
+
625
+
626
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
627
+ >>> file_path = hf_hub_download(
628
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
629
+ ... )
630
+ >>> container = av.open(file_path)
631
+
632
+ >>> # sample 32 frames
633
+ >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
634
+ >>> video = read_video_pyav(container=container, indices=indices)
635
+
636
+ >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
637
+ >>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
638
+
639
+ >>> # prepare video for the model
640
+ >>> inputs = image_processor(list(video), return_tensors="pt")
641
+
642
+ >>> # forward pass
643
+ >>> outputs = model(**inputs)
644
+ >>> last_hidden_states = outputs.last_hidden_state
645
+ >>> list(last_hidden_states.shape)
646
+ [1, 3137, 768]
647
+ ```"""
648
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
649
+ output_hidden_states = (
650
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
651
+ )
652
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
653
+
654
+ if pixel_values is None:
655
+ raise ValueError("You have to specify pixel_values")
656
+
657
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
658
+
659
+ embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
660
+
661
+ encoder_outputs = self.encoder(
662
+ embedding_output,
663
+ head_mask=head_mask,
664
+ output_attentions=output_attentions,
665
+ output_hidden_states=output_hidden_states,
666
+ return_dict=return_dict,
667
+ )
668
+ sequence_output = encoder_outputs[0]
669
+ sequence_output = self.layernorm(sequence_output)
670
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
671
+
672
+ if not return_dict:
673
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
674
+
675
+ return BaseModelOutputWithPooling(
676
+ last_hidden_state=sequence_output,
677
+ pooler_output=pooled_output,
678
+ hidden_states=encoder_outputs.hidden_states,
679
+ attentions=encoder_outputs.attentions,
680
+ )
681
+
682
+
683
+ @add_start_docstrings(
684
+ """
685
+ ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
686
+ [CLS] token) e.g. for Kinetics-400.
687
+
688
+ <Tip>
689
+
690
+ Note that it's possible to fine-tune ViViT on higher-resolution frames than the ones it has been trained on, by
691
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
692
+ position embeddings to the higher resolution.
693
+
694
+ </Tip>
695
+ """,
696
+ VIVIT_START_DOCSTRING,
697
+ )
698
+ class VivitForVideoClassification(VivitPreTrainedModel):
699
+ def __init__(self, config):
700
+ super().__init__(config)
701
+
702
+ self.num_labels = config.num_labels
703
+ self.vivit = VivitModel(config, add_pooling_layer=False)
704
+
705
+ # Classifier head
706
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
707
+
708
+ # Initialize weights and apply final processing
709
+ self.post_init()
710
+
711
+ @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
712
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
713
+ def forward(
714
+ self,
715
+ pixel_values: Optional[torch.FloatTensor] = None,
716
+ head_mask: Optional[torch.FloatTensor] = None,
717
+ labels: Optional[torch.LongTensor] = None,
718
+ output_attentions: Optional[bool] = None,
719
+ output_hidden_states: Optional[bool] = None,
720
+ interpolate_pos_encoding: bool = False,
721
+ return_dict: Optional[bool] = None,
722
+ ) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
723
+ r"""
724
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
725
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
726
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
727
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
728
+
729
+ Returns:
730
+
731
+ Examples:
732
+
733
+ ```python
734
+ >>> import av
735
+ >>> import numpy as np
736
+ >>> import torch
737
+
738
+ >>> from transformers import VivitImageProcessor, VivitForVideoClassification
739
+ >>> from huggingface_hub import hf_hub_download
740
+
741
+ >>> np.random.seed(0)
742
+
743
+
744
+ >>> def read_video_pyav(container, indices):
745
+ ... '''
746
+ ... Decode the video with PyAV decoder.
747
+ ... Args:
748
+ ... container (`av.container.input.InputContainer`): PyAV container.
749
+ ... indices (`List[int]`): List of frame indices to decode.
750
+ ... Returns:
751
+ ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
752
+ ... '''
753
+ ... frames = []
754
+ ... container.seek(0)
755
+ ... start_index = indices[0]
756
+ ... end_index = indices[-1]
757
+ ... for i, frame in enumerate(container.decode(video=0)):
758
+ ... if i > end_index:
759
+ ... break
760
+ ... if i >= start_index and i in indices:
761
+ ... frames.append(frame)
762
+ ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
763
+
764
+
765
+ >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
766
+ ... '''
767
+ ... Sample a given number of frame indices from the video.
768
+ ... Args:
769
+ ... clip_len (`int`): Total number of frames to sample.
770
+ ... frame_sample_rate (`int`): Sample every n-th frame.
771
+ ... seg_len (`int`): Maximum allowed index of sample's last frame.
772
+ ... Returns:
773
+ ... indices (`List[int]`): List of sampled frame indices
774
+ ... '''
775
+ ... converted_len = int(clip_len * frame_sample_rate)
776
+ ... end_idx = np.random.randint(converted_len, seg_len)
777
+ ... start_idx = end_idx - converted_len
778
+ ... indices = np.linspace(start_idx, end_idx, num=clip_len)
779
+ ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
780
+ ... return indices
781
+
782
+
783
+ >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
784
+ >>> file_path = hf_hub_download(
785
+ ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
786
+ ... )
787
+ >>> container = av.open(file_path)
788
+
789
+ >>> # sample 32 frames
790
+ >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
791
+ >>> video = read_video_pyav(container=container, indices=indices)
792
+
793
+ >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
794
+ >>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
795
+
796
+ >>> inputs = image_processor(list(video), return_tensors="pt")
797
+
798
+ >>> with torch.no_grad():
799
+ ... outputs = model(**inputs)
800
+ ... logits = outputs.logits
801
+
802
+ >>> # model predicts one of the 400 Kinetics-400 classes
803
+ >>> predicted_label = logits.argmax(-1).item()
804
+ >>> print(model.config.id2label[predicted_label])
805
+ LABEL_116
806
+ ```"""
807
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
808
+
809
+ outputs = self.vivit(
810
+ pixel_values,
811
+ head_mask=head_mask,
812
+ output_attentions=output_attentions,
813
+ output_hidden_states=output_hidden_states,
814
+ interpolate_pos_encoding=interpolate_pos_encoding,
815
+ return_dict=return_dict,
816
+ )
817
+
818
+ sequence_output = outputs[0]
819
+
820
+ logits = self.classifier(sequence_output[:, 0, :])
821
+
822
+ loss = None
823
+ if labels is not None:
824
+ if self.num_labels == 1:
825
+ # We are doing regression
826
+ loss_fct = MSELoss()
827
+ loss = loss_fct(logits.view(-1), labels.view(-1))
828
+ else:
829
+ loss_fct = CrossEntropyLoss()
830
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
831
+
832
+ if not return_dict:
833
+ output = (logits,) + outputs[2:]
834
+ return ((loss,) + output) if loss is not None else output
835
+
836
+ return ImageClassifierOutput(
837
+ loss=loss,
838
+ logits=logits,
839
+ hidden_states=outputs.hidden_states,
840
+ attentions=outputs.attentions,
841
+ )
842
+
843
+
844
+ __all__ = ["VivitModel", "VivitPreTrainedModel", "VivitForVideoClassification"]
docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_wav2vec2 import *
22
+ from .feature_extraction_wav2vec2 import *
23
+ from .modeling_flax_wav2vec2 import *
24
+ from .modeling_tf_wav2vec2 import *
25
+ from .modeling_wav2vec2 import *
26
+ from .processing_wav2vec2 import *
27
+ from .tokenization_wav2vec2 import *
28
+ else:
29
+ import sys
30
+
31
+ _file = globals()["__file__"]
32
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py ADDED
@@ -0,0 +1,347 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Wav2Vec2 model configuration"""
16
+
17
+ import functools
18
+ import operator
19
+
20
+ from ...configuration_utils import PretrainedConfig
21
+ from ...utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class Wav2Vec2Config(PretrainedConfig):
28
+ r"""
29
+ This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
30
+ Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
31
+ with the defaults will yield a similar configuration to that of the Wav2Vec2
32
+ [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+
38
+ Args:
39
+ vocab_size (`int`, *optional*, defaults to 32):
40
+ Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
41
+ the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`].
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ num_hidden_layers (`int`, *optional*, defaults to 12):
47
+ Number of hidden layers in the Transformer encoder.
48
+ num_attention_heads (`int`, *optional*, defaults to 12):
49
+ Number of attention heads for each attention layer in the Transformer encoder.
50
+ intermediate_size (`int`, *optional*, defaults to 3072):
51
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
52
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
53
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
54
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
55
+ hidden_dropout (`float`, *optional*, defaults to 0.1):
56
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
57
+ activation_dropout (`float`, *optional*, defaults to 0.1):
58
+ The dropout ratio for activations inside the fully connected layer.
59
+ attention_dropout (`float`, *optional*, defaults to 0.1):
60
+ The dropout ratio for the attention probabilities.
61
+ final_dropout (`float`, *optional*, defaults to 0.1):
62
+ The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
63
+ layerdrop (`float`, *optional*, defaults to 0.1):
64
+ The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
65
+ details.
66
+ initializer_range (`float`, *optional*, defaults to 0.02):
67
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
68
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
69
+ The epsilon used by the layer normalization layers.
70
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
71
+ The norm to be applied to the 1D convolutional layers in the feature encoder. One of `"group"` for group
72
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
73
+ convolutional layers.
74
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
75
+ The dropout probability for output of the feature encoder.
76
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
77
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
78
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
79
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
80
+ The dropout probability for quantized feature encoder states.
81
+ conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
82
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
83
+ feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
84
+ conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
85
+ A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
86
+ of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*. The product of the strides sets the output frame rate (see the note after this docstring).
87
+ conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 2, 2)`):
88
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
89
+ length of *conv_kernel* defines the number of convolutional layers and has to match the length of
90
+ *conv_dim*.
91
+ conv_bias (`bool`, *optional*, defaults to `False`):
92
+ Whether the 1D convolutional layers have a bias.
93
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
94
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
95
+ embeddings layer.
96
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
97
+ Number of groups of 1D convolutional positional embeddings layer.
98
+ do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
99
+ Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
100
+ True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
101
+ False` corresponds to applying layer norm after the attention layer.
102
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
103
+ Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
104
+ [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
105
+ Recognition](https://arxiv.org/abs/1904.08779).
106
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
107
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
108
+ procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
109
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
110
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
111
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
112
+ mask_time_length (`int`, *optional*, defaults to 10):
113
+ Length of vector span along the time axis.
114
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
115
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
116
+ irrespective of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
117
+ mask_time_min_masks''
118
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
119
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
120
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over
121
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
122
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
123
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
124
+ True`.
125
+ mask_feature_length (`int`, *optional*, defaults to 10):
126
+ Length of vector span along the feature axis.
127
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
128
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
129
+ step, irrespective of `mask_feature_prob`. Only relevant if
130
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
131
+ num_codevectors_per_group (`int`, *optional*, defaults to 320):
132
+ Number of entries in each quantization codebook (group).
133
+ num_codevector_groups (`int`, *optional*, defaults to 2):
134
+ Number of codevector groups for product codevector quantization.
135
+ contrastive_logits_temperature (`float`, *optional*, defaults to 0.1):
136
+ The temperature *kappa* in the contrastive loss.
137
+ feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
138
+ The dropout probability for the output of the feature encoder that's used by the quantizer.
139
+ num_negatives (`int`, *optional*, defaults to 100):
140
+ Number of negative samples for the contrastive loss.
141
+ codevector_dim (`int`, *optional*, defaults to 256):
142
+ Dimensionality of the quantized feature vectors.
143
+ proj_codevector_dim (`int`, *optional*, defaults to 256):
144
+ Dimensionality of the final projection of both the quantized and the transformer features.
145
+ diversity_loss_weight (`float`, *optional*, defaults to 0.1):
146
+ The weight of the codebook diversity loss component.
147
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
148
+ Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
149
+ instance of [`Wav2Vec2ForCTC`].
150
+ ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
151
+ Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
152
+ occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
153
+ of [`Wav2Vec2ForCTC`].
154
+ use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
155
+ Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
156
+ instance of [`Wav2Vec2ForSequenceClassification`].
157
+ classifier_proj_size (`int`, *optional*, defaults to 256):
158
+ Dimensionality of the projection before token mean-pooling for classification.
159
+ tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
160
+ A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
161
+ module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
162
+ tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
163
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
164
+ *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
165
+ tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
166
+ A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the
167
+ *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
168
+ xvector_output_dim (`int`, *optional*, defaults to 512):
169
+ Dimensionality of the *XVector* embedding vectors.
170
+ add_adapter (`bool`, *optional*, defaults to `False`):
171
+ Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
172
+ warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
173
+ adapter_kernel_size (`int`, *optional*, defaults to 3):
174
+ Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
175
+ adapter_stride (`int`, *optional*, defaults to 2):
176
+ Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
177
+ num_adapter_layers (`int`, *optional*, defaults to 3):
178
+ Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
179
+ True`.
180
+ adapter_attn_dim (`int`, *optional*):
181
+ Dimension of the attention adapter weights to be used in each attention block. An example of a model using
182
+ attention adapters is [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all).
183
+ output_hidden_size (`int`, *optional*):
184
+ Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant
185
+ if `add_adapter is True`.
186
+
187
+ Example:
188
+
189
+ ```python
190
+ >>> from transformers import Wav2Vec2Config, Wav2Vec2Model
191
+
192
+ >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration
193
+ >>> configuration = Wav2Vec2Config()
194
+
195
+ >>> # Initializing a model (with random weights) from the facebook/wav2vec2-base-960h style configuration
196
+ >>> model = Wav2Vec2Model(configuration)
197
+
198
+ >>> # Accessing the model configuration
199
+ >>> configuration = model.config
200
+ ```"""
+
+     model_type = "wav2vec2"
+
+     def __init__(
+         self,
+         vocab_size=32,
+         hidden_size=768,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout=0.1,
+         activation_dropout=0.1,
+         attention_dropout=0.1,
+         feat_proj_dropout=0.0,
+         feat_quantizer_dropout=0.0,
+         final_dropout=0.1,
+         layerdrop=0.1,
+         initializer_range=0.02,
+         layer_norm_eps=1e-5,
+         feat_extract_norm="group",
+         feat_extract_activation="gelu",
+         conv_dim=(512, 512, 512, 512, 512, 512, 512),
+         conv_stride=(5, 2, 2, 2, 2, 2, 2),
+         conv_kernel=(10, 3, 3, 3, 3, 2, 2),
+         conv_bias=False,
+         num_conv_pos_embeddings=128,
+         num_conv_pos_embedding_groups=16,
+         do_stable_layer_norm=False,
+         apply_spec_augment=True,
+         mask_time_prob=0.05,
+         mask_time_length=10,
+         mask_time_min_masks=2,
+         mask_feature_prob=0.0,
+         mask_feature_length=10,
+         mask_feature_min_masks=0,
+         num_codevectors_per_group=320,
+         num_codevector_groups=2,
+         contrastive_logits_temperature=0.1,
+         num_negatives=100,
+         codevector_dim=256,
+         proj_codevector_dim=256,
+         diversity_loss_weight=0.1,
+         ctc_loss_reduction="sum",
+         ctc_zero_infinity=False,
+         use_weighted_layer_sum=False,
+         classifier_proj_size=256,
+         tdnn_dim=(512, 512, 512, 512, 1500),
+         tdnn_kernel=(5, 3, 3, 1, 1),
+         tdnn_dilation=(1, 2, 3, 1, 1),
+         xvector_output_dim=512,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         add_adapter=False,
+         adapter_kernel_size=3,
+         adapter_stride=2,
+         num_adapter_layers=3,
+         output_hidden_size=None,
+         adapter_attn_dim=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+         self.hidden_size = hidden_size
+         self.feat_extract_norm = feat_extract_norm
+         self.feat_extract_activation = feat_extract_activation
+         self.conv_dim = list(conv_dim)
+         self.conv_stride = list(conv_stride)
+         self.conv_kernel = list(conv_kernel)
+         self.conv_bias = conv_bias
+         self.num_conv_pos_embeddings = num_conv_pos_embeddings
+         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+         self.num_feat_extract_layers = len(self.conv_dim)
+         self.num_hidden_layers = num_hidden_layers
+         self.intermediate_size = intermediate_size
+         self.hidden_act = hidden_act
+         self.num_attention_heads = num_attention_heads
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.feat_proj_dropout = feat_proj_dropout
+         self.final_dropout = final_dropout
+         self.layerdrop = layerdrop
+         self.layer_norm_eps = layer_norm_eps
+         self.initializer_range = initializer_range
+         self.vocab_size = vocab_size
+         self.do_stable_layer_norm = do_stable_layer_norm
+         self.use_weighted_layer_sum = use_weighted_layer_sum
+
+         if (
+             (len(self.conv_stride) != self.num_feat_extract_layers)
+             or (len(self.conv_kernel) != self.num_feat_extract_layers)
+             or (len(self.conv_dim) != self.num_feat_extract_layers)
+         ):
+             raise ValueError(
+                 "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
+                 " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
+                 f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
+                 f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
+             )
+
+         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
+         self.apply_spec_augment = apply_spec_augment
+         self.mask_time_prob = mask_time_prob
+         self.mask_time_length = mask_time_length
+         self.mask_time_min_masks = mask_time_min_masks
+         self.mask_feature_prob = mask_feature_prob
+         self.mask_feature_length = mask_feature_length
+         self.mask_feature_min_masks = mask_feature_min_masks
+
+         # parameters for pretraining with codevector quantized representations
+         self.num_codevectors_per_group = num_codevectors_per_group
+         self.num_codevector_groups = num_codevector_groups
+         self.contrastive_logits_temperature = contrastive_logits_temperature
+         self.feat_quantizer_dropout = feat_quantizer_dropout
+         self.num_negatives = num_negatives
+         self.codevector_dim = codevector_dim
+         self.proj_codevector_dim = proj_codevector_dim
+         self.diversity_loss_weight = diversity_loss_weight
+
+         # ctc loss
+         self.ctc_loss_reduction = ctc_loss_reduction
+         self.ctc_zero_infinity = ctc_zero_infinity
+
+         # adapter
+         self.add_adapter = add_adapter
+         self.adapter_kernel_size = adapter_kernel_size
+         self.adapter_stride = adapter_stride
+         self.num_adapter_layers = num_adapter_layers
+         self.output_hidden_size = output_hidden_size or hidden_size
+         self.adapter_attn_dim = adapter_attn_dim
+
+         # SequenceClassification-specific parameter. Feel free to ignore for other classes.
+         self.classifier_proj_size = classifier_proj_size
+
+         # XVector-specific parameters. Feel free to ignore for other classes.
+         self.tdnn_dim = list(tdnn_dim)
+         self.tdnn_kernel = list(tdnn_kernel)
+         self.tdnn_dilation = list(tdnn_dilation)
+         self.xvector_output_dim = xvector_output_dim
+
+     @property
+     def inputs_to_logits_ratio(self):
+         return functools.reduce(operator.mul, self.conv_stride, 1)
+
+
+ __all__ = ["Wav2Vec2Config"]
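
To make the masking and stride arithmetic above concrete, here is a minimal sketch (not part of the source file) that instantiates the default configuration and derives two quantities from it; the 1000-frame sequence length is an arbitrary illustration:

```python
from transformers import Wav2Vec2Config

config = Wav2Vec2Config()  # defaults: conv_stride=(5, 2, 2, 2, 2, 2, 2)

# inputs_to_logits_ratio is the product of the conv strides: 5 * 2**6 = 320,
# i.e. one encoder frame per 320 raw samples (~20 ms of 16 kHz audio).
print(config.inputs_to_logits_ratio)  # 320

# Expected number of SpecAugment time masks, per the docstring's formula
# mask_time_prob * len(time_axis) / mask_time_length, floored by mask_time_min_masks:
seq_len = 1000  # hypothetical number of encoder frames
num_masks = max(int(config.mask_time_prob * seq_len / config.mask_time_length), config.mask_time_min_masks)
print(num_masks)  # 5
```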
docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,385 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Convert Wav2Vec2 checkpoint."""
+
+ import argparse
+ import json
+ import os
+
+ import fairseq
+ import torch
+ from fairseq.data import Dictionary
+
+ from transformers import (
+     Wav2Vec2Config,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+     Wav2Vec2ForCTC,
+     Wav2Vec2ForPreTraining,
+     Wav2Vec2Processor,
+     logging,
+ )
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification
+
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger(__name__)
+
+ MAPPING = {
+     "post_extract_proj": "feature_projection.projection",
+     "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
+     "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
+     "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
+     "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
+     "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
+     "self_attn_layer_norm": "encoder.layers.*.layer_norm",
+     "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
+     "fc2": "encoder.layers.*.feed_forward.output_dense",
+     "final_layer_norm": "encoder.layers.*.final_layer_norm",
+     "encoder.layer_norm": "encoder.layer_norm",
+     "adapter_layer": "encoder.layers.*.adapter_layer",
+     "w2v_model.layer_norm": "feature_projection.layer_norm",
+     "quantizer.weight_proj": "quantizer.weight_proj",
+     "quantizer.vars": "quantizer.codevectors",
+     "project_q": "project_q",
+     "final_proj": "project_hid",
+     "w2v_encoder.proj": "lm_head",
+     "mask_emb": "masked_spec_embed",
+     "pooling_layer.linear": "projector",
+     "pooling_layer.projection": "classifier",
+ }
+ TOP_LEVEL_KEYS = [
+     "lm_head",
+     "quantizer.weight_proj",
+     "quantizer.codevectors",
+     "project_q",
+     "project_hid",
+     "projector",
+     "classifier",
+ ]
+
+
+ def read_txt_into_dict(filename):
+     result = {}
+     with open(filename, "r") as file:
+         for line_number, line in enumerate(file):
+             line = line.strip()
+             if line:
+                 words = line.split()
+                 key = line_number
+                 value = words[0]
+                 result[key] = value
+     return result
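
As a hedged usage sketch: `read_txt_into_dict` keeps only the first whitespace-separated token of each non-empty line and keys entries by line number, which is the `id2label` shape the sequence-classification branch further down expects. The file name and contents here are hypothetical:

```python
# labels.txt (hypothetical):
#   happy 1234
#   sad 987
id2label = read_txt_into_dict("labels.txt")
# -> {0: "happy", 1: "sad"}
```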
85
+
86
+
87
+ def set_recursively(key, value, full_name, weight_type, hf_pointer):
88
+ for attribute in key.split("."):
89
+ hf_pointer = getattr(hf_pointer, attribute)
90
+
91
+ hf_param_name = None
92
+ for param_key in PARAM_MAPPING.keys():
93
+ if full_name.endswith(param_key):
94
+ hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
95
+ weight_type = "param"
96
+
97
+ # fairseq uses nn.utils.weight_norm() while transformers switches to nn.utils.parametrizations.weight_norm()
98
+ # the mapping between two versions:
99
+ # https://github.com/pytorch/pytorch/blob/56935684c3dfad7841c83c719eeebecb560fe466/torch/nn/utils/parametrizations.py#L389-L395
100
+
101
+ if weight_type is not None and weight_type != "param":
102
+ if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"):
103
+ hf_shape = hf_pointer.parametrizations.weight.original0.shape
104
+ elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"):
105
+ hf_shape = hf_pointer.parametrizations.weight.original1.shape
106
+ else:
107
+ hf_shape = getattr(hf_pointer, weight_type).shape
108
+ elif weight_type is not None and weight_type == "param":
109
+ shape_pointer = hf_pointer
110
+ for attribute in hf_param_name.split("."):
111
+ shape_pointer = getattr(shape_pointer, attribute)
112
+ hf_shape = shape_pointer.shape
113
+
114
+ # let's reduce dimension
115
+ value = value[0]
116
+ else:
117
+ hf_shape = hf_pointer.shape
118
+
119
+ if hf_shape != value.shape:
120
+ raise ValueError(
121
+ f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be"
122
+ f" {value.shape} for {full_name}"
123
+ )
124
+
125
+ if weight_type == "weight":
126
+ hf_pointer.weight.data = value
127
+ elif weight_type == "weight_g":
128
+ if hasattr(hf_pointer, "weight_g"):
129
+ hf_pointer.weight_g.data = value
130
+ else:
131
+ hf_pointer.parametrizations.weight.original0.data = value
132
+ elif weight_type == "weight_v":
133
+ if hasattr(hf_pointer, "weight_v"):
134
+ hf_pointer.weight_v.data = value
135
+ else:
136
+ hf_pointer.parametrizations.weight.original1.data = value
137
+ elif weight_type == "bias":
138
+ hf_pointer.bias.data = value
139
+ elif weight_type == "param":
140
+ for attribute in hf_param_name.split("."):
141
+ hf_pointer = getattr(hf_pointer, attribute)
142
+ hf_pointer.data = value
143
+ else:
144
+ hf_pointer.data = value
145
+
146
+ logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
+
+
+ def rename_dict(key, value, full_name, weight_type, hf_dict):
+     hf_param_name = None
+     for param_key in PARAM_MAPPING.keys():
+         if full_name.endswith(param_key):
+             hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
+             weight_type = "param"
+
+     if weight_type is not None and weight_type != "param":
+         full_key = ".".join([key, weight_type])
+     elif weight_type is not None and weight_type == "param":
+         full_key = ".".join([key, hf_param_name])
+     else:
+         full_key = key
+
+     hf_dict[full_key] = value if "lm_head" in full_key else value[0]
+
+
+ PARAM_MAPPING = {
+     "W_a": "linear_1.weight",
+     "W_b": "linear_2.weight",
+     "b_a": "linear_1.bias",
+     "b_b": "linear_2.bias",
+     "ln_W": "norm.weight",
+     "ln_b": "norm.bias",
+ }
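
`PARAM_MAPPING` translates raw adapter parameter names (`W_a`, `b_a`, ...) into the attribute paths of the corresponding HF adapter module; this appears to target the attention-adapter weights of MMS-style checkpoints. A tiny sketch of the lookup that `set_recursively`/`rename_dict` perform, using a hypothetical key:

```python
full_name = "encoder.layers.0.adapter_layer.W_a"  # hypothetical fairseq key
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
print(hf_param_name)  # linear_1.weight
```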
+
+
+ def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
+     is_used = False
+     for key, mapped_key in MAPPING.items():
+         mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
+         if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
+             is_used = True
+             if "*" in mapped_key:
+                 layer_index = name.split(key)[0].split(".")[-2]
+                 mapped_key = mapped_key.replace("*", layer_index)
+             if "weight_g" in name:
+                 weight_type = "weight_g"
+             elif "weight_v" in name:
+                 weight_type = "weight_v"
+             elif "bias" in name:
+                 weight_type = "bias"
+             elif "weight" in name:
+                 # TODO: don't match quantizer.weight_proj
+                 weight_type = "weight"
+             else:
+                 weight_type = None
+             if hf_dict is not None:
+                 rename_dict(mapped_key, value, name, weight_type, hf_dict)
+             else:
+                 set_recursively(mapped_key, value, name, weight_type, hf_model)
+             return is_used
+     return is_used
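
The `*` in `MAPPING` values is a per-layer wildcard: the layer index is recovered from the fairseq name and substituted in. A standalone sketch of that resolution with one hypothetical key:

```python
name = "w2v_model.encoder.layers.3.self_attn.k_proj.weight"  # hypothetical fairseq key
key, mapped_key = "self_attn.k_proj", "wav2vec2.encoder.layers.*.attention.k_proj"

layer_index = name.split(key)[0].split(".")[-2]  # -> "3"
print(mapped_key.replace("*", layer_index))
# wav2vec2.encoder.layers.3.attention.k_proj  (weight_type "weight" is handled separately)
```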
+
+
+ def recursively_load_weights(fairseq_model, hf_model, is_headless):
+     unused_weights = []
+     fairseq_dict = fairseq_model.state_dict()
+
+     feature_extractor = hf_model.wav2vec2.feature_extractor
+
+     for name, value in fairseq_dict.items():
+         is_used = False
+         if "conv_layers" in name:
+             load_conv_layer(
+                 name,
+                 value,
+                 feature_extractor,
+                 unused_weights,
+                 hf_model.config.feat_extract_norm == "group",
+             )
+             is_used = True
+         else:
+             is_used = load_wav2vec2_layer(name, value, hf_model)
+         if not is_used:
+             unused_weights.append(name)
+
+     logger.warning(f"Unused weights: {unused_weights}")
+
+
+ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
+     name = full_name.split("conv_layers.")[-1]
+     items = name.split(".")
+     layer_id = int(items[0])
+     type_id = int(items[1])
+
+     if type_id == 0:
+         if "bias" in name:
+             if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
+                 raise ValueError(
+                     f"{full_name} has size {value.shape}, but"
+                     f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
+                 )
+             feature_extractor.conv_layers[layer_id].conv.bias.data = value
+             logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+         elif "weight" in name:
+             if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
+                 raise ValueError(
+                     f"{full_name} has size {value.shape}, but"
+                     f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
+                 )
+             feature_extractor.conv_layers[layer_id].conv.weight.data = value
+             logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
+     elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
+         if "bias" in name:
+             if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
+                 raise ValueError(
+                     f"{full_name} has size {value.shape}, but"
+                     f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
+                 )
+             feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
+             logger.info(f"Feat extract layer norm bias of layer {layer_id} was initialized from {full_name}.")
+         elif "weight" in name:
+             if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
+                 raise ValueError(
+                     f"{full_name} has size {value.shape}, but"
+                     f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
+                 )
+             feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+             logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+     else:
+         unused_weights.append(full_name)
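
`load_conv_layer` relies on fairseq's `conv_layers.<layer_id>.<type_id>` naming, where `type_id == 0` addresses the `Conv1d` itself and `type_id == 2` its norm layer. A sketch of the parsing with a hypothetical key:

```python
full_name = "feature_extractor.conv_layers.4.0.weight"  # hypothetical fairseq key
items = full_name.split("conv_layers.")[-1].split(".")
layer_id, type_id = int(items[0]), int(items[1])
print(layer_id, type_id)  # 4 0 -> the conv weight of feature-extractor layer 4
```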
+
+
+ @torch.no_grad()
+ def convert_wav2vec2_checkpoint(
+     checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
+ ):
+     """
+     Copy/paste/tweak model's weights to transformers design.
+     """
+     if config_path is not None:
+         config = Wav2Vec2Config.from_pretrained(config_path)
+     else:
+         config = Wav2Vec2Config()
+
+     if is_seq_class:
+         id2label = read_txt_into_dict(dict_path)
+         config.id2label = id2label
+         hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
+         feature_extractor = Wav2Vec2FeatureExtractor(
+             feature_size=1,
+             sampling_rate=16000,
+             padding_value=0,
+             do_normalize=True,
+             return_attention_mask=True,
+         )
+         feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+     elif is_finetuned:
+         if dict_path:
+             target_dict = Dictionary.load(dict_path)
+
+             # important change bos & pad token id since CTC symbol is <pad> and
+             # not <s> as in fairseq
+             config.bos_token_id = target_dict.pad_index
+             config.pad_token_id = target_dict.bos_index
+             config.eos_token_id = target_dict.eos_index
+             config.vocab_size = len(target_dict.symbols)
+             vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+             if not os.path.isdir(pytorch_dump_folder_path):
+                 logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+                 return
+             os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+             vocab_dict = target_dict.indices
+
+             # fairseq has the <pad> and <s> switched
+             vocab_dict["<pad>"] = 0
+             vocab_dict["<s>"] = 1
+             with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+                 json.dump(vocab_dict, vocab_handle)
+             tokenizer = Wav2Vec2CTCTokenizer(
+                 vocab_path,
+                 unk_token=target_dict.unk_word,
+                 pad_token=target_dict.pad_word,
+                 bos_token=target_dict.bos_word,
+                 eos_token=target_dict.eos_word,
+                 word_delimiter_token="|",
+                 do_lower_case=False,
+             )
+             return_attention_mask = True if config.feat_extract_norm == "layer" else False
+             feature_extractor = Wav2Vec2FeatureExtractor(
+                 feature_size=1,
+                 sampling_rate=16000,
+                 padding_value=0,
+                 do_normalize=True,
+                 return_attention_mask=return_attention_mask,
+             )
+             processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+             processor.save_pretrained(pytorch_dump_folder_path)
+
+         hf_wav2vec = Wav2Vec2ForCTC(config)
+     else:
+         hf_wav2vec = Wav2Vec2ForPreTraining(config)
+
+     if is_finetuned or is_seq_class:
+         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+             [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
+         )
+     else:
+         task_arg = argparse.Namespace(task="audio_pretraining")
+         task = fairseq.tasks.setup_task(task_arg)
+
+         model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
+
+     model = model[0].eval()
+
+     recursively_load_weights(model, hf_wav2vec, not is_finetuned)
+
+     hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+     parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
+     parser.add_argument(
+         "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
+     )
+     parser.add_argument(
+         "--is_seq_class",
+         action="store_true",
+         help="Whether the model to convert is a fine-tuned sequence classification model or not",
+     )
+     args = parser.parse_args()
+
+     is_finetuned = not args.not_finetuned and not args.is_seq_class
+     convert_wav2vec2_checkpoint(
+         args.checkpoint_path,
+         args.pytorch_dump_folder_path,
+         args.config_path,
+         args.dict_path,
+         is_finetuned,
+         args.is_seq_class,
+     )
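
A hedged end-to-end usage sketch; the paths below are placeholders and `fairseq` must be importable for the module-level import to succeed:

```python
# CLI form (equivalent to the argparse entry point above):
#   python convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py \
#       --checkpoint_path wav2vec_small_960h.pt \
#       --dict_path dict.ltr.txt \
#       --pytorch_dump_folder_path ./wav2vec2-converted

# Or calling the function directly:
convert_wav2vec2_checkpoint(
    checkpoint_path="wav2vec_small_960h.pt",          # placeholder fairseq checkpoint
    pytorch_dump_folder_path="./wav2vec2-converted",  # must already exist as a directory
    dict_path="dict.ltr.txt",                         # placeholder fairseq dictionary
    is_finetuned=True,
)
```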