kartmannXu committed on
Commit f3b99c9 · verified · 1 Parent(s): 8722c1c

Upload Idefics3blForConditionalGeneration

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
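+ In the meantime, here is a minimal, untested sketch. It assumes `<repo_id>` stands in for this repository's id, that an Idefics3-style processor is available for it (this commit only uploads the model weights and code), and it passes `trust_remote_code=True` because the architecture is registered through `auto_map` in `config.json`.
+
+ ```python
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+
+ repo_id = "<repo_id>"  # placeholder: replace with this repository's id on the Hub
+ processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
+ model = AutoModelForVision2Seq.from_pretrained(repo_id, trust_remote_code=True)
+
+ image = Image.open("example.jpg")  # any local image
+ messages = [
+     {"role": "user", "content": [
+         {"type": "image"},
+         {"type": "text", "text": "Describe this image."},
+     ]},
+ ]
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(text=prompt, images=[image], return_tensors="pt")
+
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
+ print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
+ ```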
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,170 @@
1
+ {
2
+ "architectures": [
3
+ "Idefics3blForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_idefics3bl.Idefics3blConfig",
7
+ "AutoModelForVision2Seq": "modeling_idefics3bl.Idefics3blForConditionalGeneration"
8
+ },
9
+ "bos_token_id": 1,
10
+ "eos_token_id": 2,
11
+ "image_token_id": 49190,
12
+ "midblock_end": -1,
13
+ "midblock_ratio": 1.0,
14
+ "midblock_start": -1,
15
+ "model_type": "idefics3bl",
16
+ "pad_token_id": 0,
17
+ "scale_factor": 4,
18
+ "text_config": {
19
+ "_flash_attn_2_enabled": true,
20
+ "_name_or_path": "None",
21
+ "architectures": [
22
+ "VLlama3ForCausalLMbl"
23
+ ],
24
+ "attention_bias": false,
25
+ "attention_dropout": 0.0,
26
+ "head_dim": 64,
27
+ "hidden_act": "silu",
28
+ "hidden_size": 960,
29
+ "initializer_range": 0.02,
30
+ "intermediate_size": 2560,
31
+ "is_llama_config": true,
32
+ "max_position_embeddings": 8192,
33
+ "midblock_end": 31,
34
+ "midblock_ratio": 0.6,
35
+ "midblock_start": 1,
36
+ "mlp_bias": false,
37
+ "model_type": "llama",
38
+ "neftune_noise_alpha": 0.0,
39
+ "num_attention_heads": 15,
40
+ "num_hidden_layers": 32,
41
+ "num_key_value_heads": 5,
42
+ "pad_token_id": 2,
43
+ "perceiver_config": {
44
+ "_name_or_path": "",
45
+ "add_cross_attention": false,
46
+ "architectures": null,
47
+ "attention_dropout": 0.0,
48
+ "bad_words_ids": null,
49
+ "begin_suppress_tokens": null,
50
+ "bos_token_id": null,
51
+ "chunk_size_feed_forward": 0,
52
+ "cross_attention_hidden_size": null,
53
+ "decoder_start_token_id": null,
54
+ "diversity_penalty": 0.0,
55
+ "do_sample": false,
56
+ "early_stopping": false,
57
+ "encoder_no_repeat_ngram_size": 0,
58
+ "eos_token_id": null,
59
+ "exponential_decay_length_penalty": null,
60
+ "finetuning_task": null,
61
+ "forced_bos_token_id": null,
62
+ "forced_eos_token_id": null,
63
+ "hidden_act": "silu",
64
+ "id2label": {
65
+ "0": "LABEL_0",
66
+ "1": "LABEL_1"
67
+ },
68
+ "is_decoder": false,
69
+ "is_encoder_decoder": false,
70
+ "label2id": {
71
+ "LABEL_0": 0,
72
+ "LABEL_1": 1
73
+ },
74
+ "length_penalty": 1.0,
75
+ "max_length": 20,
76
+ "min_length": 0,
77
+ "model_type": "vllama3bl",
78
+ "no_repeat_ngram_size": 0,
79
+ "num_beam_groups": 1,
80
+ "num_beams": 1,
81
+ "num_key_value_heads": 1,
82
+ "num_return_sequences": 1,
83
+ "output_attentions": false,
84
+ "output_hidden_states": false,
85
+ "output_scores": false,
86
+ "pad_token_id": null,
87
+ "prefix": null,
88
+ "problem_type": null,
89
+ "pruned_heads": {},
90
+ "qk_layer_norms_perceiver": false,
91
+ "remove_invalid_values": false,
92
+ "repetition_penalty": 1.0,
93
+ "resampler_depth": 6,
94
+ "resampler_head_dim": 96,
95
+ "resampler_n_heads": 16,
96
+ "resampler_n_latents": 64,
97
+ "return_dict": true,
98
+ "return_dict_in_generate": false,
99
+ "sep_token_id": null,
100
+ "suppress_tokens": null,
101
+ "task_specific_params": null,
102
+ "temperature": 1.0,
103
+ "tf_legacy_loss": false,
104
+ "tie_encoder_decoder": false,
105
+ "tie_word_embeddings": true,
106
+ "tokenizer_class": null,
107
+ "top_k": 50,
108
+ "top_p": 1.0,
109
+ "torch_dtype": null,
110
+ "torchscript": false,
111
+ "transformers_version": "4.46.0",
112
+ "typical_p": 1.0,
113
+ "use_bfloat16": false
114
+ },
115
+ "pixel_shuffle_factor": 4,
116
+ "pretraining_tp": 1,
117
+ "qk_layer_norms": false,
118
+ "rms_norm_eps": 1e-05,
119
+ "rope_interleaved": false,
120
+ "rope_scaling": null,
121
+ "rope_theta": 100000,
122
+ "torch_dtype": "float32",
123
+ "transformers.js_config": {
124
+ "kv_cache_dtype": {
125
+ "fp16": "float16",
126
+ "q4f16": "float16"
127
+ }
128
+ },
129
+ "use_cache": true,
130
+ "use_resampler": false,
131
+ "vocab_size": 49280
132
+ },
133
+ "tie_word_embeddings": false,
134
+ "torch_dtype": "float32",
135
+ "transformers.js_config": {
136
+ "kv_cache_dtype": {
137
+ "fp16": "float16",
138
+ "q4f16": "float16"
139
+ }
140
+ },
141
+ "transformers_version": "4.53.2",
142
+ "use_cache": true,
143
+ "vision_config": {
144
+ "attention_dropout": 0.0,
145
+ "hidden_act": "gelu_pytorch_tanh",
146
+ "hidden_size": 768,
147
+ "image_size": 512,
148
+ "initializer_range": 0.02,
149
+ "intermediate_size": 3072,
150
+ "layer_norm_eps": 1e-06,
151
+ "max_image_size": {
152
+ "longest_edge": 512
153
+ },
154
+ "midblock_end": 12,
155
+ "midblock_ratio": 0.5,
156
+ "midblock_start": 1,
157
+ "model_type": "idefics3_visionbl",
158
+ "num_attention_heads": 12,
159
+ "num_channels": 3,
160
+ "num_hidden_layers": 12,
161
+ "patch_size": 16,
162
+ "size": {
163
+ "longest_edge": 2048
164
+ },
165
+ "tie_word_embeddings": false,
166
+ "torch_dtype": "float32",
167
+ "use_base_siglip": false
168
+ },
169
+ "vocab_size": 49280
170
+ }
configuration_idefics3bl.py ADDED
@@ -0,0 +1,204 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Idefics3 model configuration"""
15
+
16
+ from transformers import PretrainedConfig
17
+ from transformers.utils import logging
18
+ from transformers import CONFIG_MAPPING, AutoConfig
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class Idefics3blVisionConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`Idefics3VisionModel`]. It is used to instantiate a
27
+ Idefics3 vision encoder according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the SigLIP checkpoint
29
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) used in the Idefics3 model
30
+ [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3).
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+ Args:
36
+ hidden_size (`int`, *optional*, defaults to 1152):
37
+ Dimensionality of the encoder layers and the pooler layer.
38
+ intermediate_size (`int`, *optional*, defaults to 3072):
39
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
40
+ num_hidden_layers (`int`, *optional*, defaults to 12):
41
+ Number of hidden layers in the Transformer encoder.
42
+ num_attention_heads (`int`, *optional*, defaults to 16):
43
+ Number of attention heads for each attention layer in the Transformer encoder.
44
+ num_channels (`int`, *optional*, defaults to 3):
45
+ Number of channels in the input images.
46
+ image_size (`int`, *optional*, defaults to 224):
47
+ The size (resolution) of each image.
48
+ patch_size (`int`, *optional*, defaults to 32):
49
+ The size (resolution) of each patch.
50
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
51
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
52
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
53
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
54
+ The epsilon used by the layer normalization layers.
55
+ attention_dropout (`float`, *optional*, defaults to 0.0):
56
+ The dropout ratio for the attention probabilities.
57
+ initializer_range (`float`, *optional*, defaults to 0.02):
58
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
59
+
60
+ Example:
61
+
62
+ ```python
63
+ >>> from transformers.models.idefics3.modeling_idefics3 import Idefics3VisionTransformer
64
+ >>> from transformers.models.idefics3.configuration_idefics3 import Idefics3VisionConfig
65
+
66
+ >>> # Initializing a Idefics3VisionConfig with google/siglip-base-patch16-224 style configuration
67
+ >>> configuration = Idefics3VisionConfig()
68
+
69
+ >>> # Initializing a Idefics3VisionTransformer (with random weights) from the google/siglip-base-patch16-224 style configuration
70
+ >>> model = Idefics3VisionTransformer(configuration)
71
+
72
+ >>> # Accessing the model configuration
73
+ >>> configuration = model.config
74
+ ```"""
75
+
76
+ model_type = "idefics3_visionbl"
77
+ base_config_key = "vision_config"
78
+
79
+ def __init__(
80
+ self,
81
+ hidden_size=1152,
82
+ intermediate_size=3072,
83
+ num_hidden_layers=12,
84
+ num_attention_heads=16,
85
+ num_channels=3,
86
+ image_size=224,
87
+ patch_size=32,
88
+ hidden_act="gelu_pytorch_tanh",
89
+ layer_norm_eps=1e-6,
90
+ attention_dropout=0.0,
91
+ initializer_range=0.02,
92
+ midblock_ratio=1.0,
93
+ midblock_start=-1,
94
+ midblock_end=-1,
95
+ **kwargs,
96
+ ):
97
+ super().__init__(**kwargs)
98
+
99
+ self.hidden_size = hidden_size
100
+ self.intermediate_size = intermediate_size
101
+ self.num_hidden_layers = num_hidden_layers
102
+ self.num_attention_heads = num_attention_heads
103
+ self.num_channels = num_channels
104
+ self.patch_size = patch_size
105
+ self.image_size = image_size
106
+ self.attention_dropout = attention_dropout
107
+ self.layer_norm_eps = layer_norm_eps
108
+ self.hidden_act = hidden_act
109
+ self.initializer_range = initializer_range
110
+ self.midblock_ratio = midblock_ratio
111
+ self.midblock_start = midblock_start
112
+ self.midblock_end = midblock_end
113
+
114
+
115
+
116
+ class Idefics3blConfig(PretrainedConfig):
117
+ r"""
118
+ This is the configuration class to store the configuration of a [`Idefics3Model`]. It is used to instantiate a
119
+ Idefics3 model according to the specified arguments, defining the model architecture. Instantiating a
120
+ configuration with the defaults will yield a similar configuration to that of the Idefics3 model
121
+ [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) architecture.
122
+
123
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
124
+ documentation from [`PretrainedConfig`] for more information.
125
+
126
+ Args:
127
+ use_cache (`bool`, *optional*, defaults to `True`):
128
+ Whether or not the model should cache the key/value pairs of the attention mechanism. Only
129
+ relevant if `config.is_decoder=True`.
130
+ image_token_id (`int`, *optional*, defaults to 128257):
131
+ The id of the "image" token.
132
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
133
+ Whether or not to tie the word embeddings with the token embeddings.
134
+ vision_config (`IdeficsVisionConfig` or `dict`, *optional*, defaults to `IdeficsVisionConfig`):
135
+ Custom vision config or dict for the vision tower
136
+ text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `LlamaConfig`):
137
+ Custom text config or dict for the text model
138
+ scale_factor (`int`, *optional*, defaults to 2):
139
+ The scale factor for the image encoder.
140
+ pad_token_id (`int`, *optional*, defaults to 128002):
141
+ The id of the padding token.
142
+
143
+ Example:
144
+ ```python
145
+ >>> from transformers import Idefics3Model, Idefics3Config
146
+ >>> # Initializing configuration
147
+ >>> configuration = Idefics3Config()
148
+ >>> # Initializing a model from the configuration
149
+ >>> model = Idefics3Model(configuration)
150
+ >>> # Accessing the model configuration
151
+ >>> configuration = model.config
152
+ ```"""
153
+
154
+ model_type = "idefics3bl"
155
+ sub_configs = {"text_config": AutoConfig, "vision_config": Idefics3blVisionConfig}
156
+
157
+ def __init__(
158
+ self,
159
+ use_cache=True,
160
+ image_token_id=128257,
161
+ tie_word_embeddings=False,
162
+ vision_config=None,
163
+ text_config=None,
164
+ scale_factor=2,
165
+ pad_token_id=128_002,
166
+ midblock_ratio=1.0,
167
+ midblock_start=-1,
168
+ midblock_end=-1,
169
+ **kwargs,
170
+ ):
171
+ self.image_token_id = image_token_id
172
+ self.use_cache = use_cache
173
+ self.tie_word_embeddings = tie_word_embeddings
174
+
175
+ if vision_config is None:
176
+ self.vision_config = Idefics3blVisionConfig()
177
+ logger.info("vision_config is None, using default vision config")
178
+ elif isinstance(vision_config, dict):
179
+ self.vision_config = Idefics3blVisionConfig(**vision_config)
180
+ elif isinstance(vision_config, Idefics3blVisionConfig):
181
+ self.vision_config = vision_config
182
+
183
+ if isinstance(text_config, dict):
184
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
185
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
186
+ elif text_config is None:
187
+ logger.info("text_config is None, using default text config")
188
+ text_config = CONFIG_MAPPING["llama"](
189
+ rms_norm_eps=1e-5,
190
+ pad_token_id=pad_token_id,
191
+ tie_word_embeddings=False,
192
+ )
193
+
194
+ self.text_config = text_config
195
+ self.scale_factor = scale_factor
196
+
197
+ self.midblock_ratio = midblock_ratio
198
+ self.midblock_start = midblock_start
199
+ self.midblock_end = midblock_end
200
+
201
+ super().__init__(**kwargs, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings)
202
+
203
+
204
+ __all__ = ["Idefics3blConfig", "Idefics3blVisionConfig"]
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.53.2"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:394df19eb2b6aa8fc83cd8e60eceb1b72f4d349214b9df66b310ae08174104e0
3
+ size 1313825808
modeling_idefics3bl.py ADDED
@@ -0,0 +1,1165 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Idefics3 model."""
16
+
17
+ from dataclasses import dataclass
18
+ from typing import Callable, Optional, Union
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.cache_utils import DynamicCache
26
+ from transformers.generation import GenerationMixin
27
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
28
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
+ from transformers.modeling_layers import GradientCheckpointingLayer
30
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput
31
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
32
+ from transformers.processing_utils import Unpack
33
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
34
+ from transformers.models.auto import AutoModel
35
+ from .configuration_idefics3bl import Idefics3blConfig, Idefics3blVisionConfig
36
+ from .modeling_llama import *
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
+ @dataclass
42
+ @auto_docstring(
43
+ custom_intro="""
44
+ Base class for Idefics3 model's outputs that may also contain a past key/values (to speed up sequential decoding).
45
+ """
46
+ )
47
+ class Idefics3BaseModelOutputWithPast(ModelOutput):
48
+ r"""
49
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
50
+ Sequence of hidden-states at the output of the last layer of the model.
51
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
52
+ hidden_size)` is output.
53
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
54
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
55
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
56
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
57
+ encoder_sequence_length, embed_size_per_head)`.
58
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
59
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
60
+ input) to speed up sequential decoding.
61
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
62
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
63
+ sequence_length, hidden_size)`.
64
+ image_hidden_states of the model produced by the vision encoder
65
+ """
66
+
67
+ last_hidden_state: Optional[torch.FloatTensor] = None
68
+ past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None
69
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
70
+ attentions: Optional[tuple[torch.FloatTensor]] = None
71
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
72
+
73
+
74
+ @dataclass
75
+ @auto_docstring(
76
+ custom_intro="""
77
+ Base class for Idefics causal language model (or autoregressive) outputs.
78
+ """
79
+ )
80
+ class Idefics3CausalLMOutputWithPast(ModelOutput):
81
+ r"""
82
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
83
+ Language modeling loss (for next-token prediction).
84
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
85
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
86
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
87
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
88
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
89
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
90
+ `past_key_values` input) to speed up sequential decoding.
91
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
92
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
93
+ sequence_length, hidden_size)`.
94
+ image_hidden_states of the model produced by the vision encoder
95
+ """
96
+
97
+ loss: Optional[torch.FloatTensor] = None
98
+ logits: Optional[torch.FloatTensor] = None
99
+ past_key_values: Optional[list[torch.FloatTensor]] = None
100
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
101
+ attentions: Optional[tuple[torch.FloatTensor]] = None
102
+ image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
103
+
104
+
105
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionEmbeddings with Idefics2->Idefics3
106
+ class Idefics3VisionEmbeddings(nn.Module):
107
+ """
108
+ This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
109
+ resolution.
110
+
111
+ The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://huggingface.co/papers/2307.06304)
112
+ which allows treating images in their native aspect ratio and without the need to resize them to the same
113
+ fixed size. In particular, we start from the original pre-trained SigLIP model
114
+ (which uses fixed-size square images) and adapt it by training on images of variable resolutions.
115
+ """
116
+
117
+ def __init__(self, config: Idefics3blVisionConfig):
118
+ super().__init__()
119
+ self.embed_dim = config.hidden_size
120
+ self.image_size = config.image_size
121
+ self.patch_size = config.patch_size
122
+
123
+ self.patch_embedding = nn.Conv2d(
124
+ in_channels=config.num_channels,
125
+ out_channels=self.embed_dim,
126
+ kernel_size=self.patch_size,
127
+ stride=self.patch_size,
128
+ padding="valid",
129
+ )
130
+
131
+ self.num_patches_per_side = self.image_size // self.patch_size
132
+ self.num_patches = self.num_patches_per_side**2
133
+ self.num_positions = self.num_patches
134
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
135
+
136
+ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
137
+ batch_size, _, max_im_h, max_im_w = pixel_values.shape
138
+
139
+ patch_embeds = self.patch_embedding(pixel_values)
140
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
141
+
142
+ max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
143
+ boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
144
+ position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
145
+
146
+ for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
147
+ nb_patches_h = p_attn_mask[:, 0].sum()
148
+ nb_patches_w = p_attn_mask[0].sum()
149
+
150
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
151
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
152
+
153
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
154
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
155
+
156
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
157
+ position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
158
+
159
+ position_ids = position_ids.to(self.position_embedding.weight.device)
160
+ embeddings = embeddings + self.position_embedding(position_ids)
161
+ return embeddings
162
+
163
+
164
+ # Copied from transformers.models.siglip.modeling_siglip.eager_attention_forward
165
+ def eager_attention_forward(
166
+ module: nn.Module,
167
+ query: torch.Tensor,
168
+ key: torch.Tensor,
169
+ value: torch.Tensor,
170
+ attention_mask: Optional[torch.Tensor],
171
+ scaling: float,
172
+ dropout: float = 0.0,
173
+ **kwargs,
174
+ ):
175
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
176
+ if attention_mask is not None:
177
+ attn_weights = attn_weights + attention_mask
178
+
179
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
180
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
181
+
182
+ attn_output = torch.matmul(attn_weights, value)
183
+ attn_output = attn_output.transpose(1, 2).contiguous()
184
+
185
+ return attn_output, attn_weights
186
+
187
+
188
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipAttention with Siglip->Idefics3Vision
189
+ class Idefics3blVisionAttention(nn.Module):
190
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
191
+
192
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
193
+ def __init__(self, config, layer_idx: int = 0):
194
+ super().__init__()
195
+ self.config = config
196
+ self.embed_dim = config.hidden_size
197
+
198
+ self.midblock_start = config.midblock_start
199
+ self.midblock_end = config.midblock_end
200
+ self.ratio = config.midblock_ratio if self.midblock_start <= layer_idx < self.midblock_end else 1.0
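+ # Layers whose index falls in [midblock_start, midblock_end) keep only a `midblock_ratio` fraction of the attention heads; the q/k/v/out projections below are sized to match.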
201
+
202
+ self.num_heads = int(config.num_attention_heads * self.ratio)
203
+ self.head_dim = self.embed_dim // config.num_attention_heads
204
+
205
+ self.q_out_size = self.num_heads * self.head_dim
206
+ # if self.head_dim * self.num_heads != self.embed_dim:
207
+ # raise ValueError(
208
+ # f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
209
+ # f" {self.num_heads})."
210
+ # )
211
+ self.scale = self.head_dim**-0.5
212
+ self.dropout = config.attention_dropout
213
+
214
+ self.k_proj = nn.Linear(self.embed_dim, self.q_out_size)
215
+ self.v_proj = nn.Linear(self.embed_dim, self.q_out_size)
216
+ self.q_proj = nn.Linear(self.embed_dim, self.q_out_size)
217
+ self.out_proj = nn.Linear(self.q_out_size, self.embed_dim)
218
+
219
+ # Ignore copy
220
+ self.is_causal = False
221
+
222
+ def forward(
223
+ self,
224
+ hidden_states: torch.Tensor,
225
+ attention_mask: Optional[torch.Tensor] = None,
226
+ output_attentions: Optional[bool] = False,
227
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
228
+ """Input shape: Batch x Time x Channel"""
229
+
230
+ batch_size, seq_length, embed_dim = hidden_states.shape
231
+
232
+ queries = self.q_proj(hidden_states)
233
+ keys = self.k_proj(hidden_states)
234
+ values = self.v_proj(hidden_states)
235
+
236
+ queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
237
+ keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
238
+ values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
239
+
240
+ attention_interface: Callable = eager_attention_forward
241
+ if self.config._attn_implementation != "eager":
242
+ if self.config._attn_implementation == "sdpa" and output_attentions:
243
+ logger.warning_once(
244
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
245
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
246
+ )
247
+ else:
248
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
249
+
250
+ attn_output, attn_weights = attention_interface(
251
+ self,
252
+ queries,
253
+ keys,
254
+ values,
255
+ attention_mask,
256
+ is_causal=self.is_causal,
257
+ scaling=self.scale,
258
+ dropout=0.0 if not self.training else self.dropout,
259
+ )
260
+
261
+ attn_output = attn_output.reshape(batch_size, seq_length, self.q_out_size).contiguous()
262
+ attn_output = self.out_proj(attn_output)
263
+
264
+ if not output_attentions:
265
+ attn_weights = None
266
+
267
+ return attn_output, attn_weights
268
+
269
+
270
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipMLP with Siglip->Idefics3Vision
271
+ class Idefics3blVisionMLP(nn.Module):
272
+ def __init__(self, config, layer_idx: int = 0):
273
+ super().__init__()
274
+ self.config = config
275
+ self.activation_fn = ACT2FN[config.hidden_act]
276
+
277
+ self.midblock_start = config.midblock_start
278
+ self.midblock_end = config.midblock_end
279
+ self.ratio = config.midblock_ratio if self.midblock_start <= layer_idx < self.midblock_end else 1.0
280
+
281
+ self.intermediate_size = int(config.intermediate_size * self.ratio)
282
+ self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
283
+ self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
284
+
285
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
286
+ hidden_states = self.fc1(hidden_states)
287
+ hidden_states = self.activation_fn(hidden_states)
288
+ hidden_states = self.fc2(hidden_states)
289
+ return hidden_states
290
+
291
+
292
+ class Idefics3SimpleMLP(nn.Module):
293
+ def __init__(self, config):
294
+ super().__init__()
295
+ input_size = config.vision_config.hidden_size * (config.scale_factor**2)
296
+ output_size = config.text_config.hidden_size
297
+ self.proj = nn.Linear(input_size, output_size, bias=False)
298
+
299
+ def forward(self, x):
300
+ return self.proj(x)
301
+
302
+
303
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3
304
+ class Idefics3blEncoderLayer(GradientCheckpointingLayer):
305
+ def __init__(self, config: Idefics3blVisionConfig, layer_id: int = 0):
306
+ super().__init__()
307
+ self.embed_dim = config.hidden_size
308
+ self.self_attn = Idefics3blVisionAttention(config, layer_idx=layer_id)
309
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
310
+ self.mlp = Idefics3blVisionMLP(config, layer_idx=layer_id)
311
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
312
+
313
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
314
+ def forward(
315
+ self,
316
+ hidden_states: torch.Tensor,
317
+ attention_mask: torch.Tensor,
318
+ output_attentions: Optional[bool] = False,
319
+ ) -> tuple[torch.FloatTensor]:
320
+ """
321
+ Args:
322
+ hidden_states (`torch.FloatTensor`):
323
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
324
+ attention_mask (`torch.FloatTensor`):
325
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
326
+ output_attentions (`bool`, *optional*, defaults to `False`):
327
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
328
+ returned tensors for more detail.
329
+ """
330
+ residual = hidden_states
331
+
332
+ hidden_states = self.layer_norm1(hidden_states)
333
+ hidden_states, attn_weights = self.self_attn(
334
+ hidden_states=hidden_states,
335
+ attention_mask=attention_mask,
336
+ output_attentions=output_attentions,
337
+ )
338
+ hidden_states = residual + hidden_states
339
+
340
+ residual = hidden_states
341
+ hidden_states = self.layer_norm2(hidden_states)
342
+ hidden_states = self.mlp(hidden_states)
343
+ hidden_states = residual + hidden_states
344
+
345
+ outputs = (hidden_states,)
346
+
347
+ if output_attentions:
348
+ outputs += (attn_weights,)
349
+
350
+ return outputs
351
+
352
+
353
+ # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3
354
+ class Idefics3blEncoder(nn.Module):
355
+ """
356
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
357
+ [`Idefics3blEncoderLayer`].
358
+
359
+ Args:
360
+ config: Idefics3Config
361
+ """
362
+
363
+ def __init__(self, config: Idefics3blConfig):
364
+ super().__init__()
365
+ self.config = config
366
+ self.layers = nn.ModuleList([Idefics3blEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)])
367
+ self.gradient_checkpointing = False
368
+
369
+ # Ignore copy
370
+ def forward(
371
+ self,
372
+ inputs_embeds,
373
+ attention_mask: Optional[torch.Tensor] = None,
374
+ output_attentions: Optional[bool] = None,
375
+ output_hidden_states: Optional[bool] = None,
376
+ return_dict: Optional[bool] = None,
377
+ ) -> Union[tuple, BaseModelOutput]:
378
+ r"""
379
+ Args:
380
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
381
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
382
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
383
+ than the model's internal embedding lookup matrix.
384
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
385
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
386
+
387
+ - 1 for tokens that are **not masked**,
388
+ - 0 for tokens that are **masked**.
389
+
390
+ [What are attention masks?](../glossary#attention-mask)
391
+ output_attentions (`bool`, *optional*):
392
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
393
+ returned tensors for more detail.
394
+ output_hidden_states (`bool`, *optional*):
395
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
396
+ for more detail.
397
+ return_dict (`bool`, *optional*):
398
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
399
+ """
400
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
401
+ output_hidden_states = (
402
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
403
+ )
404
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
405
+
406
+ encoder_states = () if output_hidden_states else None
407
+ all_attentions = () if output_attentions else None
408
+
409
+ hidden_states = inputs_embeds
410
+ for encoder_layer in self.layers:
411
+ if output_hidden_states:
412
+ encoder_states = encoder_states + (hidden_states,)
413
+ layer_outputs = encoder_layer(
414
+ hidden_states,
415
+ attention_mask,
416
+ output_attentions=output_attentions,
417
+ )
418
+
419
+ hidden_states = layer_outputs[0]
420
+
421
+ if output_attentions:
422
+ all_attentions = all_attentions + (layer_outputs[1],)
423
+
424
+ if output_hidden_states:
425
+ encoder_states = encoder_states + (hidden_states,)
426
+
427
+ if not return_dict:
428
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
429
+ return BaseModelOutput(
430
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
431
+ )
432
+
433
+
434
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
435
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
436
+ """
437
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
438
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
439
+ """
440
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
441
+ if n_rep == 1:
442
+ return hidden_states
443
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
444
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
445
+
446
+
447
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics3
448
+ class Idefics3RMSNorm(nn.Module):
449
+ def __init__(self, hidden_size, eps=1e-6):
450
+ """
451
+ Idefics3RMSNorm is equivalent to T5LayerNorm
452
+ """
453
+ super().__init__()
454
+ self.weight = nn.Parameter(torch.ones(hidden_size))
455
+ self.variance_epsilon = eps
456
+
457
+ def forward(self, hidden_states):
458
+ input_dtype = hidden_states.dtype
459
+ hidden_states = hidden_states.to(torch.float32)
460
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
461
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
462
+ return self.weight * hidden_states.to(input_dtype)
463
+
464
+ def extra_repr(self):
465
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
466
+
467
+
468
+ class Idefics3Connector(nn.Module):
469
+ def __init__(self, config):
470
+ super().__init__()
471
+ self.scale_factor = config.scale_factor
472
+ self.modality_projection = Idefics3SimpleMLP(config)
473
+
474
+ def pixel_shuffle(self, x, scale_factor=2):
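+ # Space-to-depth shuffle: (bsz, seq, dim) -> (bsz, seq / scale_factor**2, dim * scale_factor**2), reducing the number of image tokens by scale_factor**2 before the modality projection.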
475
+ bsz, seq, embed_dim = x.size()
476
+ height = width = int(seq**0.5)
477
+ x = x.view(bsz, height, width, embed_dim)
478
+ x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
479
+ x = x.permute(0, 2, 1, 3)
480
+ x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2))
481
+ x = x.permute(0, 2, 1, 3)
482
+ x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
483
+ return x
484
+
485
+ def forward(self, image_hidden_states):
486
+ image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
487
+ image_hidden_states = self.modality_projection(image_hidden_states)
488
+ return image_hidden_states
489
+
490
+
491
+ @auto_docstring
492
+ class Idefics3PreTrainedModel(PreTrainedModel):
493
+ config_class = Idefics3blConfig
494
+ base_model_prefix = "model"
495
+ supports_gradient_checkpointing = True
496
+ _no_split_modules = ["Idefics3blVisionAttention", "LlamablDecoderLayer"]
497
+ _skip_keys_device_placement = "past_key_values"
498
+ _supports_flash_attn_2 = True
499
+ _supports_sdpa = True
500
+ _supports_flex_attn = True
501
+ _supports_cache_class = True
502
+ _supports_attention_backend = True
503
+
504
+ def _init_weights(self, module):
505
+ std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
506
+
507
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
508
+ module.weight.data.normal_(mean=0.0, std=std)
509
+ if module.bias is not None:
510
+ module.bias.data.zero_()
511
+ elif isinstance(module, nn.Embedding):
512
+ module.weight.data.normal_(mean=0.0, std=std)
513
+ if module.padding_idx is not None:
514
+ module.weight.data[module.padding_idx].zero_()
515
+ elif isinstance(module, nn.LayerNorm):
516
+ module.weight.data.fill_(1.0)
517
+ module.bias.data.zero_()
518
+ elif isinstance(module, Idefics3RMSNorm):
519
+ module.weight.data.fill_(1.0)
520
+
521
+
522
+ @auto_docstring(
523
+ custom_intro="""
524
+ The Idefics3 Vision Transformer Model outputting raw image embedding.
525
+ """
526
+ )
527
+ class Idefics3VisionTransformer(Idefics3PreTrainedModel):
528
+ config_class = Idefics3blVisionConfig
529
+ _supports_sdpa = True
530
+ _supports_flash_attention_2 = True
531
+ _supports_flex_attn = True
532
+
533
+ def __init__(self, config: Idefics3blVisionConfig):
534
+ super().__init__(config)
535
+ embed_dim = config.hidden_size
536
+
537
+ self.embeddings = Idefics3VisionEmbeddings(config)
538
+ self.encoder = Idefics3blEncoder(config)
539
+ self.patch_size = config.patch_size
540
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
541
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
542
+
543
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
544
+ def get_input_embeddings(self):
545
+ return self.embeddings
546
+
547
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings
548
+ def set_input_embeddings(self, value):
549
+ self.embeddings = value
550
+
551
+ def forward(
552
+ self,
553
+ pixel_values,
554
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
555
+ output_attentions: Optional[bool] = None,
556
+ output_hidden_states: Optional[bool] = None,
557
+ return_dict: Optional[bool] = None,
558
+ ) -> Union[tuple, BaseModelOutput]:
559
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
560
+ output_hidden_states = (
561
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
562
+ )
563
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
564
+
565
+ batch_size = pixel_values.size(0)
566
+ if patch_attention_mask is None:
567
+ patch_size = self.patch_size
568
+ patch_attention_mask = torch.ones(
569
+ (
570
+ batch_size,
571
+ pixel_values.size(2) // patch_size,
572
+ pixel_values.size(3) // patch_size,
573
+ )
574
+ )
575
+ patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device)
576
+
577
+ hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
578
+
579
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
580
+ # The call to `_upad_input` in `_flash_attention_forward` is expensive
581
+ # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
582
+ # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
583
+ if not torch.any(~patch_attention_mask):
584
+ patch_attention_mask = None
585
+ elif not self._use_flash_attention_2:
586
+ patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
587
+
588
+ encoder_outputs = self.encoder(
589
+ inputs_embeds=hidden_states,
590
+ attention_mask=patch_attention_mask,
591
+ output_attentions=output_attentions,
592
+ output_hidden_states=output_hidden_states,
593
+ return_dict=return_dict,
594
+ )
595
+
596
+ last_hidden_state = encoder_outputs[0]
597
+ last_hidden_state = self.post_layernorm(last_hidden_state)
598
+
599
+ if not return_dict:
600
+ return (last_hidden_state,) + encoder_outputs[1:]
601
+
602
+ return BaseModelOutput(
603
+ last_hidden_state=last_hidden_state,
604
+ hidden_states=encoder_outputs.hidden_states,
605
+ attentions=encoder_outputs.attentions,
606
+ )
607
+
608
+
609
+ class LlamablAttention(LlamaAttention):
610
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
611
+
612
+ def __init__(self, config: LlamaConfig, layer_idx: int):
613
+ super().__init__(config, layer_idx=layer_idx)
614
+ self.config = config
615
+ self.layer_idx = layer_idx
616
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
617
+
618
+ self.midblock_start = config.midblock_start
619
+ self.midblock_end = config.midblock_end
620
+ self.ratio = config.midblock_ratio if self.midblock_start <= layer_idx < self.midblock_end else 1.0
621
+
622
+ self.num_attention_heads = int(config.num_attention_heads * self.ratio)
623
+ self.num_key_value_heads = int(config.num_key_value_heads * self.ratio)
624
+
625
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
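+ # The grouping above uses the unscaled config values; with this checkpoint's settings (15 heads, 5 kv heads, ratio 0.6 -> 9 and 3 after scaling) the group size of 3 is unchanged.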
626
+ self.scaling = self.head_dim**-0.5
627
+ self.attention_dropout = config.attention_dropout
628
+ self.is_causal = True
629
+
630
+ self.q_proj = nn.Linear(
631
+ config.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias
632
+ )
633
+ self.k_proj = nn.Linear(
634
+ config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
635
+ )
636
+ self.v_proj = nn.Linear(
637
+ config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
638
+ )
639
+ self.o_proj = nn.Linear(
640
+ self.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
641
+ )
642
+
643
+
644
+ class LlamablMLP(LlamaMLP):
645
+ def __init__(self, config, layer_idx: int):
646
+ super().__init__(config)
647
+ self.config = config
648
+ self.hidden_size = config.hidden_size
649
+
650
+ self.midblock_start = config.midblock_start
651
+ self.midblock_end = config.midblock_end
652
+ self.ratio = 0.5 if self.midblock_start <= layer_idx < self.midblock_end else 1.0
653
+
654
+ self.intermediate_size = int(config.intermediate_size * self.ratio)
655
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
656
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
657
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
658
+ self.act_fn = ACT2FN[config.hidden_act]
659
+
660
+
661
+ class LlamablDecoderLayer(LlamaDecoderLayer):
662
+ def __init__(self, config: LlamaConfig, layer_idx: int):
663
+ super().__init__(config, layer_idx=layer_idx)
664
+
665
+ self.self_attn = LlamablAttention(config=config, layer_idx=layer_idx)
666
+ self.mlp = LlamablMLP(config, layer_idx=layer_idx)
667
+
668
+
669
+ class LLama3blModel(LlamaModel):
670
+ def __init__(self, config):
671
+ super().__init__(config)
672
+ self.midblock_ratio = self.config.midblock_ratio
673
+ self.midblock_start = self.config.midblock_start
674
+ self.midblock_end = self.config.midblock_end
675
+
676
+ self.layers = nn.ModuleList(
677
+ [LlamablDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
678
+ )
679
+ # Initialize weights and apply final processing
680
+ self.post_init()
681
+
682
+
683
+ @auto_docstring(
684
+ custom_intro="""
685
+ Idefics3 model consisting of a SIGLIP vision encoder and Llama3 language decoder
686
+ """
687
+ )
688
+ class Idefics3Model(Idefics3PreTrainedModel):
689
+ def __init__(self, config: Idefics3blConfig):
690
+ super().__init__(config)
691
+ self.padding_idx = self.config.text_config.pad_token_id
692
+ self.vocab_size = self.config.text_config.vocab_size
693
+
694
+ self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config)
695
+ self.connector = Idefics3Connector(config)
696
+ self.text_model = LLama3blModel(config.text_config)
697
+
698
+ self.image_seq_len = int(
699
+ ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2)
700
+ )
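+ # With this repo's config (image_size=512, patch_size=16, scale_factor=4): (512 // 16)**2 / 4**2 = 1024 / 16 = 64 image tokens per image.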
701
+ self.image_token_id = self.config.image_token_id
702
+
703
+ self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
704
+
705
+ self.post_init()
706
+
707
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.enable_input_require_grads
708
+ def enable_input_require_grads(self):
709
+ """
710
+ Enables the gradients for the input embeddings.
711
+
712
+ This is useful for lora when using gradient checkpointing.
713
+ c.f. https://github.com/huggingface/peft/issues/1402#issuecomment-1913675032
714
+
715
+ Override to set output.requires_grad = True for both the decoder's and vision model's embeddings.
716
+ """
717
+
718
+ def get_lowest_module(module):
719
+ if len(list(module.children())) == 0:
720
+ # If the module has no children, it is a leaf module (e.g., Linear, Conv2d, etc.)
721
+ return module
722
+ else:
723
+ # Recursively call the function on each child module
724
+ return get_lowest_module(list(module.children())[0])
725
+
726
+ def make_inputs_require_grads(module, input, output):
727
+ output.requires_grad_(True)
728
+
729
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
730
+ self._vision_require_grads_hook = get_lowest_module(self.vision_model).register_forward_hook(
731
+ make_inputs_require_grads
732
+ )
733
+
734
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.disable_input_require_grads
735
+ def disable_input_require_grads(self):
736
+ self._text_require_grads_hook.remove()
737
+ self._vision_require_grads_hook.remove()
738
+
739
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.get_input_embeddings
740
+ def get_input_embeddings(self):
741
+ return self.text_model.get_input_embeddings()
742
+
743
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Model.set_input_embeddings
744
+ def set_input_embeddings(self, value):
745
+ self.text_model.set_input_embeddings(value)
746
+
747
+ def inputs_merger(
748
+ self,
749
+ input_ids: torch.LongTensor,
750
+ inputs_embeds: Optional[torch.Tensor],
751
+ image_hidden_states: Optional[torch.Tensor],
752
+ ):
753
+ """
754
+ This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
755
+ The merging happens as follows:
756
+ - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
757
+ - We get the image hidden states for the image through the vision encoder and that hidden state, after a pixel shuffle operation, is then projected into the text embedding space.
758
+ We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
759
+         - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_tok_around_image vector_tok_4`. That sequence is fed to the LM.
760
+         - To fit the format of that sequence, `input_ids`, `inputs_embeds` and `attention_mask` are all adapted to insert the image hidden states.
761
+ """
762
+ special_image_token_mask = input_ids == self.image_token_id
763
+ # Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
764
+ new_inputs_embeds = inputs_embeds.clone()
765
+ # Flatten `image_hidden_states` if not flat yet
766
+ image_hidden_states = image_hidden_states.view(-1, image_hidden_states.shape[-1])
767
+ # cast to the dtype of the input_embeds to support quantized models
768
+ image_hidden_states = image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
769
+ new_inputs_embeds[special_image_token_mask] = image_hidden_states
770
+ return new_inputs_embeds
771
+
772
+ def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
773
+ """
774
+ Encodes images into continuous embeddings that can be forwarded to the language model.
775
+
776
+ Args:
777
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
778
+ The tensors corresponding to the input images.
779
+ pixel_attention_mask (`torch.LongTensor`, *optional*):
780
+ The attention mask indicating padded regions in the image.
781
+ """
782
+ batch_size, num_images, num_channels, height, width = pixel_values.shape
783
+ pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility
784
+ pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
785
+
786
+ # Remove padding images - padding images are full 0.
787
+ nb_values_per_image = pixel_values.shape[1:].numel()
788
+ real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
789
+ pixel_values = pixel_values[real_images_inds].contiguous()
790
+
791
+ # Handle the vision attention mask
792
+ if pixel_attention_mask is None:
793
+ pixel_attention_mask = torch.ones(
794
+ size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
795
+ dtype=torch.bool,
796
+ device=pixel_values.device,
797
+ )
798
+ else:
799
+ # Remove padding images from the mask
800
+ pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
801
+ pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
802
+
803
+ patch_size = self.config.vision_config.patch_size
804
+ patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
805
+ patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
806
+ patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
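+         # Shape sketch (assuming, e.g., a 364x364 pixel mask and patch_size=14): the two unfolds give
+         # (num_images, 26, 26, 14, 14); a patch counts as real if any of its 14*14 pixels is unmasked,
+         # yielding a (num_images, 26, 26) boolean patch_attention_mask.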
807
+
808
+ # Get sequence from the vision encoder
809
+         image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
810
+         image_hidden_states = image_hidden_states.last_hidden_state
811
+
812
+         # Modality projection & resampling
813
+         image_hidden_states = self.connector(image_hidden_states)
814
+ return image_hidden_states
815
+
816
+ @can_return_tuple
817
+ @auto_docstring(
818
+ custom_intro="""
819
+ Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
820
+ the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
821
+ max_num_images is the maximum number of images among the batch_size samples in the batch.
822
+ Padding images are not needed beyond padding the pixel_values at the entrance of the model.
823
+ For efficiency, we only pass through the vision_model's forward the real images by
824
+ discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
825
+ image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
826
+ """
827
+ )
828
+ def forward(
829
+ self,
830
+ input_ids: Optional[torch.LongTensor] = None,
831
+ attention_mask: Optional[torch.Tensor] = None,
832
+ position_ids: Optional[torch.LongTensor] = None,
833
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
834
+ inputs_embeds: Optional[torch.FloatTensor] = None,
835
+ pixel_values: Optional[torch.FloatTensor] = None,
836
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
837
+ image_hidden_states: Optional[torch.FloatTensor] = None,
838
+ use_cache: Optional[bool] = None,
839
+ output_attentions: Optional[bool] = None,
840
+ output_hidden_states: Optional[bool] = None,
841
+ cache_position: Optional[torch.LongTensor] = None,
842
+ return_dict: Optional[bool] = None,
843
+ **kwargs: Unpack[FlashAttentionKwargs],
844
+ ) -> Union[tuple, Idefics3BaseModelOutputWithPast]:
845
+ r"""
846
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
847
+ Mask to avoid performing attention on padding pixel indices.
848
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
849
+ The hidden states of the image encoder after modality projection.
850
+ """
851
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
852
+ output_hidden_states = (
853
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
854
+ )
855
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
856
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
857
+
858
+ if self.training and self.text_model.gradient_checkpointing and use_cache:
859
+ logger.warning_once(
860
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
861
+ )
862
+ use_cache = False
863
+
864
+ # retrieve input_ids and inputs_embeds
865
+ if input_ids is not None:
866
+ batch_size, seq_length = input_ids.shape
867
+ elif inputs_embeds is not None:
868
+ batch_size, seq_length, _ = inputs_embeds.shape
869
+ else:
870
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
871
+
872
+ past_seen_tokens = 0
873
+ if use_cache:
874
+ if past_key_values is None:
875
+ past_key_values = DynamicCache()
876
+ past_seen_tokens = past_key_values.get_seq_length()
877
+
878
+ if inputs_embeds is None:
879
+ inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)
880
+
881
+ # START VISUAL INPUTS INTEGRATION
882
+ if pixel_values is not None and image_hidden_states is not None:
883
+ raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
884
+ elif pixel_values is not None:
885
+ image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
886
+ elif image_hidden_states is not None:
887
+ image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
888
+
889
+ if past_seen_tokens == 0 and input_ids is not None and image_hidden_states is not None:
890
+             # When generating, we don't want to replace generated image_token_id tokens with image embeddings
891
+             # that simply don't exist
892
+ inputs_embeds = self.inputs_merger(
893
+ input_ids=input_ids,
894
+ inputs_embeds=inputs_embeds,
895
+ image_hidden_states=image_hidden_states,
896
+ )
897
+
898
+ outputs = self.text_model(
899
+ inputs_embeds=inputs_embeds,
900
+ attention_mask=attention_mask,
901
+ position_ids=position_ids,
902
+ past_key_values=past_key_values,
903
+ use_cache=use_cache,
904
+ output_attentions=output_attentions,
905
+ output_hidden_states=output_hidden_states,
906
+ cache_position=cache_position,
907
+ return_dict=True,
908
+ **kwargs,
909
+ )
910
+
911
+ return Idefics3BaseModelOutputWithPast(
912
+ last_hidden_state=outputs.last_hidden_state,
913
+ past_key_values=outputs.past_key_values,
914
+ hidden_states=outputs.hidden_states,
915
+ attentions=outputs.attentions,
916
+ image_hidden_states=image_hidden_states,
917
+ )
918
+
919
+
920
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
921
+
922
+
923
+ @auto_docstring(
924
+ custom_intro="""
925
+     The Idefics3 Model with a language modeling head. It is made up of a SigLIP vision encoder and a Llama3 language decoder, with a language modeling head on top.
926
+ """
927
+ )
928
+ class Idefics3blForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin):
929
+ _tied_weights_keys = ["lm_head.weight"]
930
+
931
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3
932
+ def __init__(self, config):
933
+ super().__init__(config)
934
+ self.model = Idefics3Model(config)
935
+ self.image_token_id = self.config.image_token_id
936
+
937
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
938
+ self.vocab_size = config.text_config.vocab_size
939
+
940
+ # Initialize weights and apply final processing
941
+ self.post_init()
942
+
943
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.enable_input_require_grads
944
+ def enable_input_require_grads(self):
945
+ """
946
+ Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
947
+ the model weights fixed.
948
+ """
949
+
950
+ def make_inputs_require_grads(module, input, output):
951
+ output.requires_grad_(True)
952
+
953
+ self._text_require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
954
+ self._vision_require_grads_hook = self.model.vision_model.get_input_embeddings().register_forward_hook(
955
+ make_inputs_require_grads
956
+ )
957
+
958
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.disable_input_require_grads
959
+ def disable_input_require_grads(self):
960
+ self._text_require_grads_hook.remove()
961
+ self._vision_require_grads_hook.remove()
962
+
963
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_input_embeddings
964
+ def get_input_embeddings(self):
965
+ return self.model.text_model.get_input_embeddings()
966
+
967
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_input_embeddings
968
+ def set_input_embeddings(self, value):
969
+ self.model.text_model.set_input_embeddings(value)
970
+
971
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.get_output_embeddings
972
+ def get_output_embeddings(self):
973
+ return self.lm_head
974
+
975
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.set_output_embeddings
976
+ def set_output_embeddings(self, new_embeddings):
977
+ self.lm_head = new_embeddings
978
+
979
+ def get_image_features(self, pixel_values: torch.FloatTensor, pixel_attention_mask: torch.LongTensor = None):
980
+ return self.model.get_image_features(pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask)
981
+
982
+ @can_return_tuple
983
+ @auto_docstring
984
+ def forward(
985
+ self,
986
+ input_ids: Optional[torch.LongTensor] = None,
987
+ attention_mask: Optional[torch.Tensor] = None,
988
+ position_ids: Optional[torch.LongTensor] = None,
989
+ past_key_values: Optional[list[torch.FloatTensor]] = None,
990
+ inputs_embeds: Optional[torch.FloatTensor] = None,
991
+ pixel_values: Optional[torch.FloatTensor] = None,
992
+ pixel_attention_mask: Optional[torch.BoolTensor] = None,
993
+ image_hidden_states: Optional[torch.FloatTensor] = None,
994
+ labels: Optional[torch.LongTensor] = None,
995
+ use_cache: Optional[bool] = None,
996
+ output_attentions: Optional[bool] = None,
997
+ output_hidden_states: Optional[bool] = None,
998
+ cache_position: Optional[torch.LongTensor] = None,
999
+ return_dict: Optional[bool] = None,
1000
+ logits_to_keep: Union[int, torch.Tensor] = 0,
1001
+ **kwargs: Unpack[KwargsForCausalLM],
1002
+ ) -> Union[tuple, Idefics3CausalLMOutputWithPast]:
1003
+ r"""
1004
+ pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
1005
+ Mask to avoid performing attention on padding pixel indices.
1006
+ image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1007
+ The hidden states of the image encoder after modality projection.
1008
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1009
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1010
+ config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `Idefics3ForConditionalGeneration`).
1011
+ Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
1012
+ computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1013
+
1014
+ Example:
1015
+
1016
+ ```python
1017
+ >>> import requests
1018
+ >>> import torch
1019
+ >>> from PIL import Image
1020
+ >>> from io import BytesIO
1021
+
1022
+ >>> from transformers import AutoProcessor, AutoModelForVision2Seq
1023
+ >>> from transformers.image_utils import load_image
1024
+
1025
+ >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
1026
+ >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
1027
+ >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
1028
+ >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
1029
+
1030
+ >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
1031
+ >>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", torch_dtype=torch.bfloat16, device_map="auto")
1032
+
1033
+ >>> # Create inputs
1034
+ >>> messages = [
1035
+ ... {
1036
+ ... "role": "user",
1037
+ ... "content": [
1038
+ ... {"type": "image"},
1039
+ ... {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
1040
+ ... {"type": "image"},
1041
+ ... {"type": "text", "text": "What can we see in this image?"},
1042
+ ... ]
1043
+ ... },
1044
+ ... {
1045
+ ... "role": "user",
1046
+ ... "content": [
1047
+ ... {"type": "image"},
1048
+ ... {"type": "text", "text": "In which city is that bridge located?"},
1049
+ ... ]
1050
+ ... }
1051
+ ... ]
1052
+
1053
+ >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
1054
+ >>> images = [[image1, image2], [image3]]
1055
+ >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)
1056
+
1057
+ >>> # Generate
1058
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
1059
+ >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
1060
+
1061
+ >>> print(generated_texts[0])
1062
+ Assistant: There are buildings, trees, lights, and water visible in this image.
1063
+
1064
+ >>> print(generated_texts[1])
1065
+ Assistant: The bridge is in San Francisco.
1066
+ ```"""
1067
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1068
+ output_hidden_states = (
1069
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1070
+ )
1071
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1072
+
1073
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1074
+ outputs = self.model(
1075
+ input_ids=input_ids,
1076
+ attention_mask=attention_mask,
1077
+ position_ids=position_ids,
1078
+ past_key_values=past_key_values,
1079
+ inputs_embeds=inputs_embeds,
1080
+ pixel_values=pixel_values,
1081
+ pixel_attention_mask=pixel_attention_mask,
1082
+ image_hidden_states=image_hidden_states,
1083
+ use_cache=use_cache,
1084
+ output_attentions=output_attentions,
1085
+ output_hidden_states=output_hidden_states,
1086
+ cache_position=cache_position,
1087
+ return_dict=True,
1088
+ **kwargs,
1089
+ )
1090
+
1091
+ hidden_states = outputs[0]
1092
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1093
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1094
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
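+         # Illustrative note: `logits_to_keep=0` gives `slice(0, None)` (logits for every position), while
+         # `logits_to_keep=1`, as typically used during generation, computes logits for the last position only.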
1095
+
1096
+ loss = None
1097
+ if labels is not None:
1098
+ loss = self.loss_function(
1099
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
1100
+ )
1101
+
1102
+ return Idefics3CausalLMOutputWithPast(
1103
+ loss=loss,
1104
+ logits=logits,
1105
+ past_key_values=outputs.past_key_values,
1106
+ hidden_states=outputs.hidden_states,
1107
+ attentions=outputs.attentions,
1108
+ image_hidden_states=outputs.image_hidden_states,
1109
+ )
1110
+
1111
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.prepare_inputs_for_generation
1112
+ def prepare_inputs_for_generation(
1113
+ self,
1114
+ input_ids,
1115
+ past_key_values=None,
1116
+ attention_mask=None,
1117
+ inputs_embeds=None,
1118
+ cache_position=None,
1119
+ pixel_values=None,
1120
+ pixel_attention_mask=None,
1121
+ image_hidden_states=None,
1122
+ logits_to_keep=None,
1123
+ **kwargs,
1124
+ ):
1125
+ # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
1126
+ # precedence is moved to the model, we can remove this fn)
1127
+
1128
+ model_inputs = super().prepare_inputs_for_generation(
1129
+ input_ids,
1130
+ past_key_values=past_key_values,
1131
+ attention_mask=attention_mask,
1132
+ inputs_embeds=inputs_embeds,
1133
+ cache_position=cache_position,
1134
+ pixel_values=pixel_values,
1135
+ pixel_attention_mask=pixel_attention_mask,
1136
+ image_hidden_states=image_hidden_states,
1137
+ logits_to_keep=logits_to_keep,
1138
+ **kwargs,
1139
+ )
1140
+
1141
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1142
+ # but IDEFICS requires both ids and embeds to be present
1143
+ if inputs_embeds is not None and cache_position[0] == 0:
1144
+ model_inputs["input_ids"] = input_ids
1145
+
1146
+ if image_hidden_states is not None:
1147
+ model_inputs["pixel_values"] = None
1148
+ model_inputs["pixel_attention_mask"] = None
1149
+
1150
+ return model_inputs
1151
+
1152
+ # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration._update_model_kwargs_for_generation
1153
+ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
1154
+ model_kwargs = super()._update_model_kwargs_for_generation(
1155
+ outputs=outputs,
1156
+ model_kwargs=model_kwargs,
1157
+ is_encoder_decoder=is_encoder_decoder,
1158
+ **kwargs,
1159
+ )
1160
+ # Get the precomputed image_hidden_states
1161
+ model_kwargs["image_hidden_states"] = outputs.image_hidden_states
1162
+ return model_kwargs
1163
+
1164
+
1165
+ __all__ = ["Idefics3ForConditionalGeneration", "Idefics3PreTrainedModel", "Idefics3Model", "Idefics3VisionTransformer"]
modeling_llama.py ADDED
@@ -0,0 +1,816 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ from typing import Callable, Optional, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache
28
+ from transformers.generation import GenerationMixin
29
+ from transformers.integrations import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import create_causal_mask
31
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
32
+ from transformers.modeling_layers import GradientCheckpointingLayer
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPast,
35
+ CausalLMOutputWithPast,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutputWithPast,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
41
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
42
+ from transformers.processing_utils import Unpack
43
+ from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging
44
+ from transformers.models.llama.configuration_llama import LlamaConfig
45
+
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+
50
+ @use_kernel_forward_from_hub("RMSNorm")
51
+ class LlamaRMSNorm(nn.Module):
52
+ def __init__(self, hidden_size, eps=1e-6):
53
+ """
54
+ LlamaRMSNorm is equivalent to T5LayerNorm
55
+ """
56
+ super().__init__()
57
+ self.weight = nn.Parameter(torch.ones(hidden_size))
58
+ self.variance_epsilon = eps
59
+
60
+ def forward(self, hidden_states):
61
+ input_dtype = hidden_states.dtype
62
+ hidden_states = hidden_states.to(torch.float32)
63
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
64
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
65
+ return self.weight * hidden_states.to(input_dtype)
66
+
67
+ def extra_repr(self):
68
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
69
+
70
+
71
+ class LlamaRotaryEmbedding(nn.Module):
72
+ def __init__(self, config: LlamaConfig, device=None):
73
+ super().__init__()
74
+ # BC: "rope_type" was originally "type"
75
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
76
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
77
+ else:
78
+ self.rope_type = "default"
79
+ self.max_seq_len_cached = config.max_position_embeddings
80
+ self.original_max_seq_len = config.max_position_embeddings
81
+
82
+ self.config = config
83
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
84
+
85
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
86
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
87
+ self.original_inv_freq = self.inv_freq
88
+
89
+ @torch.no_grad()
90
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
91
+ def forward(self, x, position_ids):
92
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
93
+ position_ids_expanded = position_ids[:, None, :].float()
94
+
95
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
96
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
97
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
98
+ emb = torch.cat((freqs, freqs), dim=-1)
99
+ cos = emb.cos() * self.attention_scaling
100
+ sin = emb.sin() * self.attention_scaling
101
+
102
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
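+     # Shape note (illustrative): `freqs` is (batch, seq_len, head_dim // 2), so after the concat `cos`
+     # and `sin` are (batch, seq_len, head_dim) and broadcast against q/k in `apply_rotary_pos_emb`.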
103
+
104
+
105
+ def rotate_half(x):
106
+ """Rotates half the hidden dims of the input."""
107
+ x1 = x[..., : x.shape[-1] // 2]
108
+ x2 = x[..., x.shape[-1] // 2 :]
109
+ return torch.cat((-x2, x1), dim=-1)
110
+
111
+
112
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
113
+ """Applies Rotary Position Embedding to the query and key tensors.
114
+
115
+ Args:
116
+ q (`torch.Tensor`): The query tensor.
117
+ k (`torch.Tensor`): The key tensor.
118
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
119
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
120
+ position_ids (`torch.Tensor`, *optional*):
121
+ Deprecated and unused.
122
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
123
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
124
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
125
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
126
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
127
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
128
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
129
+ Returns:
130
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
131
+ """
132
+ cos = cos.unsqueeze(unsqueeze_dim)
133
+ sin = sin.unsqueeze(unsqueeze_dim)
134
+ q_embed = (q * cos) + (rotate_half(q) * sin)
135
+ k_embed = (k * cos) + (rotate_half(k) * sin)
136
+ return q_embed, k_embed
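+     # Usage sketch (shapes only, illustrative): with q and k of shape (batch, num_heads, seq_len, head_dim)
+     # and cos/sin of shape (batch, seq_len, head_dim), the default unsqueeze_dim=1 turns cos/sin into
+     # (batch, 1, seq_len, head_dim) so the element-wise products broadcast over the heads dimension.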
137
+
138
+
139
+ class LlamaMLP(nn.Module):
140
+ def __init__(self, config):
141
+ super().__init__()
142
+ self.config = config
143
+ self.hidden_size = config.hidden_size
144
+ self.intermediate_size = config.intermediate_size
145
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
146
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
147
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
148
+ self.act_fn = ACT2FN[config.hidden_act]
149
+
150
+ def forward(self, x):
151
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
152
+ return down_proj
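+     # This is the gated (SwiGLU-style) MLP used by Llama: down_proj(act_fn(gate_proj(x)) * up_proj(x)),
+     # where `act_fn` is `config.hidden_act` (typically "silu" for Llama checkpoints).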
153
+
154
+
155
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
156
+ """
157
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
158
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
159
+ """
160
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
161
+ if n_rep == 1:
162
+ return hidden_states
163
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
164
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
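+     # Example (illustrative): with 32 attention heads and 8 key/value heads, n_rep = 4 and a
+     # (batch, 8, seq_len, head_dim) tensor becomes (batch, 32, seq_len, head_dim), matching
+     # torch.repeat_interleave(hidden_states, n_rep, dim=1).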
165
+
166
+
167
+ def eager_attention_forward(
168
+ module: nn.Module,
169
+ query: torch.Tensor,
170
+ key: torch.Tensor,
171
+ value: torch.Tensor,
172
+ attention_mask: Optional[torch.Tensor],
173
+ scaling: float,
174
+ dropout: float = 0.0,
175
+ **kwargs,
176
+ ):
177
+ key_states = repeat_kv(key, module.num_key_value_groups)
178
+ value_states = repeat_kv(value, module.num_key_value_groups)
179
+
180
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
181
+ if attention_mask is not None:
182
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
183
+ attn_weights = attn_weights + causal_mask
184
+
185
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
186
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
187
+ attn_output = torch.matmul(attn_weights, value_states)
188
+ attn_output = attn_output.transpose(1, 2).contiguous()
189
+
190
+ return attn_output, attn_weights
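+     # Plain scaled dot-product attention: softmax(q @ k^T * scaling + causal_mask) @ v, with
+     # `scaling = head_dim ** -0.5` supplied by the caller and dropout applied only during training.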
191
+
192
+
193
+ class LlamaAttention(nn.Module):
194
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
195
+
196
+ def __init__(self, config: LlamaConfig, layer_idx: int):
197
+ super().__init__()
198
+ self.config = config
199
+ self.layer_idx = layer_idx
200
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
201
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
202
+ self.scaling = self.head_dim**-0.5
203
+ self.attention_dropout = config.attention_dropout
204
+ self.is_causal = True
205
+
206
+ self.q_proj = nn.Linear(
207
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
208
+ )
209
+ self.k_proj = nn.Linear(
210
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
211
+ )
212
+ self.v_proj = nn.Linear(
213
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
214
+ )
215
+ self.o_proj = nn.Linear(
216
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
217
+ )
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
223
+ attention_mask: Optional[torch.Tensor],
224
+ past_key_value: Optional[Cache] = None,
225
+ cache_position: Optional[torch.LongTensor] = None,
226
+ **kwargs: Unpack[FlashAttentionKwargs],
227
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
228
+ input_shape = hidden_states.shape[:-1]
229
+ hidden_shape = (*input_shape, -1, self.head_dim)
230
+
231
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
232
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
233
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
234
+
235
+ cos, sin = position_embeddings
236
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
237
+
238
+ if past_key_value is not None:
239
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
240
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
241
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
242
+
243
+ attention_interface: Callable = eager_attention_forward
244
+ if self.config._attn_implementation != "eager":
245
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
246
+
247
+ attn_output, attn_weights = attention_interface(
248
+ self,
249
+ query_states,
250
+ key_states,
251
+ value_states,
252
+ attention_mask,
253
+ dropout=0.0 if not self.training else self.attention_dropout,
254
+ scaling=self.scaling,
255
+ **kwargs,
256
+ )
257
+
258
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
259
+ attn_output = self.o_proj(attn_output)
260
+ return attn_output, attn_weights
261
+
262
+
263
+ class LlamaDecoderLayer(GradientCheckpointingLayer):
264
+ def __init__(self, config: LlamaConfig, layer_idx: int):
265
+ super().__init__()
266
+ self.hidden_size = config.hidden_size
267
+
268
+ self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
269
+
270
+ self.mlp = LlamaMLP(config)
271
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
272
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
273
+
274
+ def forward(
275
+ self,
276
+ hidden_states: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ position_ids: Optional[torch.LongTensor] = None,
279
+ past_key_value: Optional[Cache] = None,
280
+ output_attentions: Optional[bool] = False,
281
+ use_cache: Optional[bool] = False,
282
+ cache_position: Optional[torch.LongTensor] = None,
283
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
284
+ **kwargs: Unpack[FlashAttentionKwargs],
285
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
286
+ residual = hidden_states
287
+ hidden_states = self.input_layernorm(hidden_states)
288
+
289
+ # Self Attention
290
+ hidden_states, self_attn_weights = self.self_attn(
291
+ hidden_states=hidden_states,
292
+ attention_mask=attention_mask,
293
+ position_ids=position_ids,
294
+ past_key_value=past_key_value,
295
+ output_attentions=output_attentions,
296
+ use_cache=use_cache,
297
+ cache_position=cache_position,
298
+ position_embeddings=position_embeddings,
299
+ **kwargs,
300
+ )
301
+ hidden_states = residual + hidden_states
302
+
303
+ # Fully Connected
304
+ residual = hidden_states
305
+ hidden_states = self.post_attention_layernorm(hidden_states)
306
+ hidden_states = self.mlp(hidden_states)
307
+ hidden_states = residual + hidden_states
308
+
309
+ outputs = (hidden_states,)
310
+ if output_attentions:
311
+ outputs += (self_attn_weights,)
312
+
313
+ return outputs
314
+
315
+
316
+ @auto_docstring
317
+ class LlamaPreTrainedModel(PreTrainedModel):
318
+ config_class = LlamaConfig
319
+ base_model_prefix = "model"
320
+ supports_gradient_checkpointing = True
321
+ _no_split_modules = ["LlamaDecoderLayer"]
322
+ _skip_keys_device_placement = ["past_key_values"]
323
+ _supports_flash_attn_3 = True
324
+ _supports_flash_attn_2 = True
325
+ _supports_sdpa = True
326
+ _supports_flex_attn = True
327
+ _supports_cache_class = True
328
+ _supports_quantized_cache = True
329
+ _supports_static_cache = True
330
+ _supports_attention_backend = True
331
+
332
+ def _init_weights(self, module):
333
+ std = self.config.initializer_range
334
+ if isinstance(module, nn.Linear):
335
+ module.weight.data.normal_(mean=0.0, std=std)
336
+ if module.bias is not None:
337
+ module.bias.data.zero_()
338
+ elif isinstance(module, nn.Embedding):
339
+ module.weight.data.normal_(mean=0.0, std=std)
340
+ if module.padding_idx is not None:
341
+ module.weight.data[module.padding_idx].zero_()
342
+ elif isinstance(module, LlamaRMSNorm):
343
+ module.weight.data.fill_(1.0)
344
+
345
+
346
+ @auto_docstring
347
+ class LlamaModel(LlamaPreTrainedModel):
348
+ def __init__(self, config: LlamaConfig):
349
+ super().__init__(config)
350
+ self.padding_idx = config.pad_token_id
351
+ self.vocab_size = config.vocab_size
352
+
353
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
354
+ self.layers = nn.ModuleList(
355
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
356
+ )
357
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
358
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
359
+ self.gradient_checkpointing = False
360
+
361
+ # Initialize weights and apply final processing
362
+ self.post_init()
363
+
364
+ def get_input_embeddings(self):
365
+ return self.embed_tokens
366
+
367
+ def set_input_embeddings(self, value):
368
+ self.embed_tokens = value
369
+
370
+ @can_return_tuple
371
+ @auto_docstring
372
+ def forward(
373
+ self,
374
+ input_ids: Optional[torch.LongTensor] = None,
375
+ attention_mask: Optional[torch.Tensor] = None,
376
+ position_ids: Optional[torch.LongTensor] = None,
377
+ past_key_values: Optional[Cache] = None,
378
+ inputs_embeds: Optional[torch.FloatTensor] = None,
379
+ use_cache: Optional[bool] = None,
380
+ output_attentions: Optional[bool] = None,
381
+ output_hidden_states: Optional[bool] = None,
382
+ cache_position: Optional[torch.LongTensor] = None,
383
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
384
+ ) -> BaseModelOutputWithPast:
385
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
386
+ output_hidden_states = (
387
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
388
+ )
389
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
390
+
391
+ if (input_ids is None) ^ (inputs_embeds is not None):
392
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
393
+
394
+ if self.gradient_checkpointing and self.training and use_cache:
395
+ logger.warning_once(
396
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
397
+ )
398
+ use_cache = False
399
+
400
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
401
+ if not isinstance(past_key_values, (type(None), Cache)):
402
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
403
+
404
+ if inputs_embeds is None:
405
+ inputs_embeds = self.embed_tokens(input_ids)
406
+
407
+ if use_cache and past_key_values is None:
408
+ past_key_values = DynamicCache()
409
+
410
+ if cache_position is None:
411
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
412
+ cache_position = torch.arange(
413
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
414
+ )
415
+
416
+ if position_ids is None:
417
+ position_ids = cache_position.unsqueeze(0)
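+             # Illustrative: for a prompt of length L with an empty cache, cache_position is arange(0, L);
+             # during generation with T tokens already cached it is the single position [T], and position_ids
+             # defaults to the same values with a leading batch dimension.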
418
+
419
+ causal_mask = create_causal_mask(
420
+ config=self.config,
421
+ input_embeds=inputs_embeds,
422
+ attention_mask=attention_mask,
423
+ cache_position=cache_position,
424
+ past_key_values=past_key_values,
425
+ position_ids=position_ids,
426
+ )
427
+
428
+ hidden_states = inputs_embeds
429
+
430
+ # create position embeddings to be shared across the decoder layers
431
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
432
+
433
+ # decoder layers
434
+ all_hidden_states = () if output_hidden_states else None
435
+ all_self_attns = () if output_attentions else None
436
+
437
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
438
+ if output_hidden_states:
439
+ all_hidden_states += (hidden_states,)
440
+
441
+ layer_outputs = decoder_layer(
442
+ hidden_states,
443
+ attention_mask=causal_mask,
444
+ position_ids=position_ids,
445
+ past_key_value=past_key_values,
446
+ output_attentions=output_attentions,
447
+ use_cache=use_cache,
448
+ cache_position=cache_position,
449
+ position_embeddings=position_embeddings,
450
+ **flash_attn_kwargs,
451
+ )
452
+
453
+ hidden_states = layer_outputs[0]
454
+
455
+ if output_attentions:
456
+ all_self_attns += (layer_outputs[1],)
457
+
458
+ hidden_states = self.norm(hidden_states)
459
+
460
+ # add hidden states from the last decoder layer
461
+ if output_hidden_states:
462
+ all_hidden_states += (hidden_states,)
463
+
464
+ return BaseModelOutputWithPast(
465
+ last_hidden_state=hidden_states,
466
+ past_key_values=past_key_values if use_cache else None,
467
+ hidden_states=all_hidden_states,
468
+ attentions=all_self_attns,
469
+ )
470
+
471
+
472
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
473
+
474
+
475
+ @auto_docstring
476
+ class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
477
+ _tied_weights_keys = ["lm_head.weight"]
478
+ _tp_plan = {"lm_head": "colwise_rep"}
479
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
480
+
481
+ def __init__(self, config):
482
+ super().__init__(config)
483
+ self.model = LlamaModel(config)
484
+ self.vocab_size = config.vocab_size
485
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
486
+
487
+ # Initialize weights and apply final processing
488
+ self.post_init()
489
+
490
+ def get_input_embeddings(self):
491
+ return self.model.embed_tokens
492
+
493
+ def set_input_embeddings(self, value):
494
+ self.model.embed_tokens = value
495
+
496
+ def get_output_embeddings(self):
497
+ return self.lm_head
498
+
499
+ def set_output_embeddings(self, new_embeddings):
500
+ self.lm_head = new_embeddings
501
+
502
+ def set_decoder(self, decoder):
503
+ self.model = decoder
504
+
505
+ def get_decoder(self):
506
+ return self.model
507
+
508
+ @can_return_tuple
509
+ @auto_docstring
510
+ def forward(
511
+ self,
512
+ input_ids: Optional[torch.LongTensor] = None,
513
+ attention_mask: Optional[torch.Tensor] = None,
514
+ position_ids: Optional[torch.LongTensor] = None,
515
+ past_key_values: Optional[Cache] = None,
516
+ inputs_embeds: Optional[torch.FloatTensor] = None,
517
+ labels: Optional[torch.LongTensor] = None,
518
+ use_cache: Optional[bool] = None,
519
+ output_attentions: Optional[bool] = None,
520
+ output_hidden_states: Optional[bool] = None,
521
+ cache_position: Optional[torch.LongTensor] = None,
522
+ logits_to_keep: Union[int, torch.Tensor] = 0,
523
+ **kwargs: Unpack[KwargsForCausalLM],
524
+ ) -> CausalLMOutputWithPast:
525
+ r"""
526
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
527
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
528
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
529
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
530
+
531
+ Example:
532
+
533
+ ```python
534
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
535
+
536
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
537
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
538
+
539
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
540
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
541
+
542
+ >>> # Generate
543
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
544
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
545
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
546
+ ```"""
547
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
548
+ output_hidden_states = (
549
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
550
+ )
551
+
552
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
553
+ outputs: BaseModelOutputWithPast = self.model(
554
+ input_ids=input_ids,
555
+ attention_mask=attention_mask,
556
+ position_ids=position_ids,
557
+ past_key_values=past_key_values,
558
+ inputs_embeds=inputs_embeds,
559
+ use_cache=use_cache,
560
+ output_attentions=output_attentions,
561
+ output_hidden_states=output_hidden_states,
562
+ cache_position=cache_position,
563
+ **kwargs,
564
+ )
565
+
566
+ hidden_states = outputs.last_hidden_state
567
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
568
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
569
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
570
+
571
+ loss = None
572
+ if labels is not None:
573
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
574
+
575
+ return CausalLMOutputWithPast(
576
+ loss=loss,
577
+ logits=logits,
578
+ past_key_values=outputs.past_key_values,
579
+ hidden_states=outputs.hidden_states,
580
+ attentions=outputs.attentions,
581
+ )
582
+
583
+
584
+ @auto_docstring(
585
+ custom_intro="""
586
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
587
+
588
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
589
+ (e.g. GPT-2) do.
590
+
591
+ Since it does classification on the last token, it requires to know the position of the last token. If a
592
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
593
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
594
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
595
+ each row of the batch).
596
+ """
597
+ )
598
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
599
+ def __init__(self, config):
600
+ super().__init__(config)
601
+ self.num_labels = config.num_labels
602
+ self.model = LlamaModel(config)
603
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
604
+
605
+ # Initialize weights and apply final processing
606
+ self.post_init()
607
+
608
+ def get_input_embeddings(self):
609
+ return self.model.embed_tokens
610
+
611
+ def set_input_embeddings(self, value):
612
+ self.model.embed_tokens = value
613
+
614
+ @can_return_tuple
615
+ @auto_docstring
616
+ def forward(
617
+ self,
618
+ input_ids: Optional[torch.LongTensor] = None,
619
+ attention_mask: Optional[torch.Tensor] = None,
620
+ position_ids: Optional[torch.LongTensor] = None,
621
+ past_key_values: Optional[Cache] = None,
622
+ inputs_embeds: Optional[torch.FloatTensor] = None,
623
+ labels: Optional[torch.LongTensor] = None,
624
+ use_cache: Optional[bool] = None,
625
+ output_attentions: Optional[bool] = None,
626
+ output_hidden_states: Optional[bool] = None,
627
+ ) -> SequenceClassifierOutputWithPast:
628
+ r"""
629
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
630
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
631
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
632
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
633
+ """
634
+
635
+ transformer_outputs: BaseModelOutputWithPast = self.model(
636
+ input_ids,
637
+ attention_mask=attention_mask,
638
+ position_ids=position_ids,
639
+ past_key_values=past_key_values,
640
+ inputs_embeds=inputs_embeds,
641
+ use_cache=use_cache,
642
+ output_attentions=output_attentions,
643
+ output_hidden_states=output_hidden_states,
644
+ )
645
+ hidden_states = transformer_outputs.last_hidden_state
646
+ logits = self.score(hidden_states)
647
+
648
+ if input_ids is not None:
649
+ batch_size = input_ids.shape[0]
650
+ else:
651
+ batch_size = inputs_embeds.shape[0]
652
+
653
+ if self.config.pad_token_id is None and batch_size != 1:
654
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
655
+ if self.config.pad_token_id is None:
656
+ last_non_pad_token = -1
657
+ elif input_ids is not None:
658
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
659
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
660
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
661
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
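+             # Worked example (illustrative): with pad_token_id=0 and input_ids = [[5, 7, 0, 0]],
+             # non_pad_mask = [[1, 1, 0, 0]] and token_indices = [0, 1, 2, 3], so
+             # (token_indices * non_pad_mask).argmax(-1) = 1, i.e. the rightmost non-padding token.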
662
+ else:
663
+ last_non_pad_token = -1
664
+ logger.warning_once(
665
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
666
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
667
+ )
668
+
669
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
670
+
671
+ loss = None
672
+ if labels is not None:
673
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
674
+
675
+ return SequenceClassifierOutputWithPast(
676
+ loss=loss,
677
+ logits=pooled_logits,
678
+ past_key_values=transformer_outputs.past_key_values,
679
+ hidden_states=transformer_outputs.hidden_states,
680
+ attentions=transformer_outputs.attentions,
681
+ )
682
+
683
+
684
+ @auto_docstring
685
+ class LlamaForQuestionAnswering(LlamaPreTrainedModel):
686
+ base_model_prefix = "transformer"
687
+
688
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
689
+ def __init__(self, config):
690
+ super().__init__(config)
691
+ self.transformer = LlamaModel(config)
692
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
693
+
694
+ # Initialize weights and apply final processing
695
+ self.post_init()
696
+
697
+ def get_input_embeddings(self):
698
+ return self.transformer.embed_tokens
699
+
700
+ def set_input_embeddings(self, value):
701
+ self.transformer.embed_tokens = value
702
+
703
+ @can_return_tuple
704
+ @auto_docstring
705
+ def forward(
706
+ self,
707
+ input_ids: Optional[torch.LongTensor] = None,
708
+ attention_mask: Optional[torch.Tensor] = None,
709
+ position_ids: Optional[torch.LongTensor] = None,
710
+ past_key_values: Optional[Cache] = None,
711
+ inputs_embeds: Optional[torch.FloatTensor] = None,
712
+ start_positions: Optional[torch.LongTensor] = None,
713
+ end_positions: Optional[torch.LongTensor] = None,
714
+ output_attentions: Optional[bool] = None,
715
+ output_hidden_states: Optional[bool] = None,
716
+ **kwargs,
717
+ ) -> QuestionAnsweringModelOutput:
718
+ outputs: BaseModelOutputWithPast = self.transformer(
719
+ input_ids,
720
+ attention_mask=attention_mask,
721
+ position_ids=position_ids,
722
+ past_key_values=past_key_values,
723
+ inputs_embeds=inputs_embeds,
724
+ output_attentions=output_attentions,
725
+ output_hidden_states=output_hidden_states,
726
+ )
727
+
728
+ sequence_output = outputs.last_hidden_state
729
+
730
+ logits = self.qa_outputs(sequence_output)
731
+ start_logits, end_logits = logits.split(1, dim=-1)
732
+ start_logits = start_logits.squeeze(-1).contiguous()
733
+ end_logits = end_logits.squeeze(-1).contiguous()
734
+
735
+ loss = None
736
+ if start_positions is not None and end_positions is not None:
737
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
738
+
739
+ return QuestionAnsweringModelOutput(
740
+ loss=loss,
741
+ start_logits=start_logits,
742
+ end_logits=end_logits,
743
+ hidden_states=outputs.hidden_states,
744
+ attentions=outputs.attentions,
745
+ )
746
+
747
+
748
+ @auto_docstring
749
+ class LlamaForTokenClassification(LlamaPreTrainedModel):
750
+ def __init__(self, config):
751
+ super().__init__(config)
752
+ self.num_labels = config.num_labels
753
+ self.model = LlamaModel(config)
754
+ if getattr(config, "classifier_dropout", None) is not None:
755
+ classifier_dropout = config.classifier_dropout
756
+ elif getattr(config, "hidden_dropout", None) is not None:
757
+ classifier_dropout = config.hidden_dropout
758
+ else:
759
+ classifier_dropout = 0.1
760
+ self.dropout = nn.Dropout(classifier_dropout)
761
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
762
+
763
+ # Initialize weights and apply final processing
764
+ self.post_init()
765
+
766
+ def get_input_embeddings(self):
767
+ return self.model.embed_tokens
768
+
769
+ def set_input_embeddings(self, value):
770
+ self.model.embed_tokens = value
771
+
772
+ @can_return_tuple
773
+ @auto_docstring
774
+ def forward(
775
+ self,
776
+ input_ids: Optional[torch.LongTensor] = None,
777
+ attention_mask: Optional[torch.Tensor] = None,
778
+ position_ids: Optional[torch.LongTensor] = None,
779
+ past_key_values: Optional[Cache] = None,
780
+ inputs_embeds: Optional[torch.FloatTensor] = None,
781
+ labels: Optional[torch.LongTensor] = None,
782
+ use_cache: Optional[bool] = None,
783
+ output_attentions: Optional[bool] = None,
784
+ output_hidden_states: Optional[bool] = None,
785
+ ) -> TokenClassifierOutput:
786
+ r"""
787
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
788
+             Labels for computing the token classification loss. Indices should be in `[0, ...,
789
+             config.num_labels - 1]`; a Cross-Entropy classification loss is computed over the labelled tokens.
791
+ """
792
+
793
+ outputs: BaseModelOutputWithPast = self.model(
794
+ input_ids,
795
+ attention_mask=attention_mask,
796
+ position_ids=position_ids,
797
+ past_key_values=past_key_values,
798
+ inputs_embeds=inputs_embeds,
799
+ use_cache=use_cache,
800
+ output_attentions=output_attentions,
801
+ output_hidden_states=output_hidden_states,
802
+ )
803
+ sequence_output = outputs.last_hidden_state
804
+ sequence_output = self.dropout(sequence_output)
805
+ logits = self.score(sequence_output)
806
+
807
+ loss = None
808
+ if labels is not None:
809
+ loss = self.loss_function(logits, labels, self.config)
810
+
811
+ return TokenClassifierOutput(
812
+ loss=loss,
813
+ logits=logits,
814
+ hidden_states=outputs.hidden_states,
815
+ attentions=outputs.attentions,
816
+ )