JinghuiLuAstronaut committed on
Commit
49cd5bc
·
verified ·
1 Parent(s): f394074

Add files using upload-large-folder tool

Browse files
Files changed (20) hide show
  1. LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log +0 -0
  2. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py +530 -0
  3. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py +27 -0
  4. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py +95 -0
  5. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py +800 -0
  6. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py +154 -0
  7. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py +28 -0
  8. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py +186 -0
  9. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py +1405 -0
  10. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py +123 -0
  11. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py +30 -0
  12. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py +963 -0
  13. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py +279 -0
  14. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py +0 -0
  15. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py +191 -0
  16. LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py +166 -0
  17. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt +3 -0
  18. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt +3 -0
  19. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt +3 -0
  20. LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt +3 -0
LTA_openwebtext_dualt/logs/fullycoupled_outwd0p5_8gpu/lta_owt_gpt2cached_len1024_fullycoupled_rmsnorm_nobias_adamw_wd0p1_outwd0p5_nanogpt_tf32_ddit768x12_gbs512_8gpu_1m_20260514_215642.log ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/cohere/modeling_cohere.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/cohere/modular_cohere.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_cohere.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2024 Cohere team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+
26
+ # This file is based on the LLama model definition file in transformers
27
+
28
+
29
+ from collections.abc import Callable
30
+ from typing import Optional
31
+
32
+ import torch
33
+ from torch import nn
34
+
35
+ from ...activations import ACT2FN
36
+ from ...cache_utils import Cache, DynamicCache
37
+ from ...generation import GenerationMixin
38
+ from ...integrations import use_kernelized_func
39
+ from ...masking_utils import create_causal_mask
40
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
41
+ from ...modeling_layers import GradientCheckpointingLayer
42
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
43
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
+ from ...processing_utils import Unpack
46
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
47
+ from ...utils.generic import maybe_autocast, merge_with_config_defaults
48
+ from ...utils.output_capturing import capture_outputs
49
+ from .configuration_cohere import CohereConfig
50
+
51
+
52
class CohereLayerNorm(nn.Module):
    """Bias-free layer normalization over the last dimension, computed in float32.

    The learnable ``weight`` has shape ``hidden_size``; the result is cast back
    to the dtype of the input.
    """

    def __init__(self, hidden_size=None, eps=1e-5, bias=False):
        """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
        super().__init__()
        # NOTE: `bias` is accepted but never read in this class; no bias parameter is created.
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last axis and scale by ``weight``."""
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)

        # Center once and reuse the centered values for both variance and output.
        centered = states_fp32 - states_fp32.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered * torch.rsqrt(variance + self.variance_epsilon)

        scaled = self.weight.to(torch.float32) * normalized
        return scaled.to(original_dtype)
67
+
68
+
69
class CohereRotaryEmbedding(nn.Module):
    """Builds the cos/sin rotary position tables from inverse frequencies and position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: CohereConfig, device=None):
        super().__init__()
        self.config = config

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" is handled by the static method below; every other type is
        # looked up in the shared ROPE_INIT_FUNCTIONS registry.
        rope_init_fn: Callable = (
            self.compute_default_rope_parameters
            if self.rope_type == "default"
            else ROPE_INIT_FUNCTIONS[self.rope_type]
        )
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are registered as non-persistent; a clone of the initial
        # frequencies is kept separately under `original_inv_freq`.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: CohereConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        rope_theta = config.rope_parameters["rope_theta"]
        # A missing or falsy `head_dim` falls back to hidden_size // num_attention_heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        # Even indices 0, 2, ..., built as int64 and then moved/cast to float on `device`.
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (rope_theta**exponents)

        attention_factor = 1.0  # Unused in this type of RoPE
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to the dtype of ``x``."""
        batch = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        position_row = position_ids[:, None, :].float()

        raw_device_type = x.device.type
        device_type = raw_device_type if isinstance(raw_device_type, str) and raw_device_type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (freq_column.float() @ position_row.float()).transpose(1, 2)
            # diff from Llama: we interleave() instead of cat()
            emb = torch.repeat_interleave(freqs, 2, dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
132
+
133
+
134
class CohereMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden, intermediate = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = intermediate
        # Registration order (gate, up, down) is kept so parameter ordering is unchanged.
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and project back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
148
+
149
+
150
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head axis.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input of shape (batch,
    num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first keeps the 4-D shape requirement even on the n_rep == 1 fast path.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
160
+
161
+
162
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Keys and values are first repeated `module.num_key_value_groups` times along the head axis.
    `attention_mask`, when given, is added to the scaled scores before the softmax. Extra keyword
    arguments are accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`; `attn_output` has its dims 1 and 2 swapped and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax runs in float32 and is cast back to the query dtype before dropout.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
185
+
186
+
187
def rotate_half(x):
    """Rotate interleaved pairs along the last axis: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).

    Note that this function is different from e.g. Llama, which splits the last axis into two halves.
    """
    even, odd = x[..., ::2], x[..., 1::2]
    # Stack each (-odd, even) pair on a new trailing axis, then merge that axis back.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
193
+
194
+
195
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    The rotation is computed on float copies of `q` and `k`; both results are cast back to the
    original dtype of `q`.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    output_dtype = q.dtype
    q_fp32, k_fp32 = q.float(), k.float()
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    rotated_q = (q_fp32 * cos_b) + (rotate_half(q_fp32) * sin_b)
    rotated_k = (k_fp32 * cos_b) + (rotate_half(k_fp32) * sin_b)
    return rotated_q.to(dtype=output_dtype), rotated_k.to(dtype=output_dtype)
221
+
222
+
223
@use_kernelized_func(apply_rotary_pos_emb)
class CohereAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: CohereConfig, layer_idx: int | None = None):
        """
        Args:
            config (`CohereConfig`): Model configuration; supplies the head counts, hidden size, attention
                bias/dropout and the `use_qk_norm` switch read below.
            layer_idx (`int`, *optional*): Index handed to `past_key_values.update(...)` in `forward`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Resolve head_dim the same way CohereRotaryEmbedding.compute_default_rope_parameters does:
        # a missing *or* `None` `head_dim` falls back to hidden_size // num_attention_heads.
        # A plain `getattr(config, "head_dim", default)` would return None when the attribute
        # exists but is unset, and `self.head_dim**-0.5` below would then raise a TypeError.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.use_qk_norm = config.use_qk_norm
        if self.use_qk_norm:
            # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
            self.q_norm = CohereLayerNorm(
                hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
            )
            self.k_norm = CohereLayerNorm(
                hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Args:
            hidden_states (`torch.Tensor`): Layer input; all dimensions but the last are kept as the output's
                leading shape.
            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`): The `(cos, sin)` rotary tables.
            attention_mask (`torch.Tensor`, *optional*): Passed through unchanged to the attention backend.
            past_key_values (`Cache`, *optional*): When given, updated with this layer's key/value states and
                its returned key/value states are used for attention.
            **kwargs: Forwarded unchanged to the attention backend.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has gone through `o_proj`, and `attn_weights` is
            whatever the selected attention backend returns as its second value.
        """
        input_shape = hidden_states.shape[:-1]
        # -1 lets the head count be inferred from each projection's width.
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape)
        key_states = self.k_proj(hidden_states).view(hidden_shape)
        value_states = self.v_proj(hidden_states).view(hidden_shape)

        if self.use_qk_norm:  # main diff from Llama
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Look up the configured backend; `eager_attention_forward` is the fallback passed to the lookup.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
306
+
307
+
308
class CohereDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer with parallel attention and MLP branches.

    A single `input_layernorm` output feeds both `self_attn` and `mlp`; the two branch outputs are added
    to the residual in one step.
    """

    def __init__(self, config: CohereConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
        self.mlp = CohereMLP(config)
        self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Forwarded to `self_attn` as a keyword argument; not used directly by this layer.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*):
                Forwarded to `self_attn` as a keyword argument; not used directly by this layer.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.

        Returns:
            `torch.Tensor`: `residual + attention_output + mlp_output`. The attention weights returned by
            `self_attn` are discarded here.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states_attention, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # The MLP consumes the same normalized input as attention (parallel branches, not sequential).
        hidden_states_mlp = self.mlp(hidden_states)
        hidden_states = residual + hidden_states_attention + hidden_states_mlp
        return hidden_states
359
+
360
+
361
# Shared base class for the Cohere models in this file: it only declares class-level
# settings (config type, weight prefix, capability flags) and defines no methods itself.
@auto_docstring
class CoherePreTrainedModel(PreTrainedModel):
    config: CohereConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CohereDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Capability flags read by the PreTrainedModel machinery (exact semantics live in the base class).
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class it is associated with —
    # presumably used by the output-capturing utilities; verify against `capture_outputs`.
    _can_record_outputs = {
        "hidden_states": CohereDecoderLayer,
        "attentions": CohereAttention,
    }
378
+
379
+
380
@auto_docstring
class CohereModel(CoherePreTrainedModel):
    def __init__(self, config: CohereConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Submodules are registered in the same order as before so parameter ordering is unchanged.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
        self.rotary_emb = CohereRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two inputs may be provided: reject both-missing and both-given.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Default positions continue from however many tokens the cache already holds.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = new_positions.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # The (cos, sin) pair is computed once and shared by every decoder layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
452
+
453
+
454
@auto_docstring
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = CohereModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.logit_scale = config.logit_scale
        self.tie_word_embeddings = config.tie_word_embeddings

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >> from transformers import AutoTokenizer, CohereForCausalLM

        >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
        >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

        >> prompt = "Hey, are you conscious? Can you talk to me?"
        >> inputs = tokenizer(prompt, return_tensors="pt")

        >> # Generate
        >> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        # An int keeps the trailing `logits_to_keep` positions (0 gives slice(0, None), i.e. all of them);
        # anything else is used directly as the index along the sequence axis.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep

        # main diff from Llama: logits are multiplied by `logit_scale`
        logits = self.lm_head(hidden_states[:, slice_indices, :]) * self.logit_scale

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
528
+
529
+
530
# Names exported by `from ... import *` for this module.
__all__ = ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Static type checkers take this branch and see the submodules' names via star-imports.
    from .configuration_granitemoeshared import *
    from .modeling_granitemoeshared import *
else:
    import sys

    # At runtime this package's entry in `sys.modules` is replaced by a `_LazyModule` built from
    # the import structure that `define_import_structure` derives from this file's path —
    # presumably so the submodules are only imported on first attribute access; see `_LazyModule`.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/configuration_granitemoeshared.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
4
+ # and OPT implementations in this library. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """GraniteMoeShared model configuration"""
20
+
21
+ from huggingface_hub.dataclasses import strict
22
+
23
+ from ...configuration_utils import PreTrainedConfig
24
+ from ...modeling_rope_utils import RopeParameters
25
+ from ...utils import auto_docstring
26
+
27
+
28
@auto_docstring(checkpoint="ibm-granite/granite-speech-3.2-8b")
@strict
class GraniteMoeSharedConfig(PreTrainedConfig):
    r"""
    embedding_multiplier (`float`, *optional*, defaults to 1.0):
        embedding multiplier
    logits_scaling (`float`, *optional*, defaults to 1.0):
        divisor for output logits
    residual_multiplier (`float`, *optional*, defaults to 1.0):
        residual multiplier
    attention_multiplier (`float`, *optional*, defaults to 1.0):
        attention multiplier
    shared_intermediate_size (`int`, *optional*, defaults to 0):
        intermediate size for shared experts.

    ```python
    >>> from transformers import GraniteMoeSharedModel, GraniteMoeSharedConfig

    >>> # Initializing a GraniteMoeShared configuration with default values
    >>> configuration = GraniteMoeSharedConfig()

    >>> # Initializing a model from that configuration
    >>> model = GraniteMoeSharedModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "granitemoeshared"
    keys_to_ignore_at_inference = ["past_key_values"]

    vocab_size: int = 32000
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    # `None` is replaced by `num_attention_heads` in `__post_init__`.
    num_key_value_heads: int | None = None
    hidden_act: str = "silu"
    max_position_embeddings: int = 2048
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = None
    bos_token_id: int | None = 1
    eos_token_id: int | list[int] | None = 2
    tie_word_embeddings: bool = False
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: float | int | None = 0.0
    embedding_multiplier: float | int | None = 1.0
    logits_scaling: float | int | None = 1.0
    residual_multiplier: float | int | None = 1.0
    attention_multiplier: float | int | None = 1.0
    num_local_experts: int | None = 8
    num_experts_per_tok: int | None = 2
    output_router_logits: bool | None = False
    router_aux_loss_coef: float | None = 0.001
    shared_intermediate_size: int = 0

    def __post_init__(self, **kwargs):
        # Default the key/value head count to the attention head count before the base-class hook runs.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        super().__post_init__(**kwargs)
93
+
94
+
95
# Names exported by `from ... import *` for this module.
__all__ = ["GraniteMoeSharedConfig"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modeling_granitemoeshared.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/granitemoeshared/modular_granitemoeshared.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_granitemoeshared.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ from collections.abc import Callable
22
+ from typing import Optional, TypedDict
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import functional as F
27
+
28
+ from ... import initialization as init
29
+ from ...activations import ACT2FN
30
+ from ...cache_utils import Cache, DynamicCache
31
+ from ...generation import GenerationMixin
32
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
33
+ from ...masking_utils import create_causal_mask
34
+ from ...modeling_layers import GradientCheckpointingLayer
35
+ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
36
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from ...processing_utils import Unpack
39
+ from ...utils import TransformersKwargs, auto_docstring
40
+ from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
41
+ from ...utils.output_capturing import capture_outputs
42
+ from .configuration_granitemoeshared import GraniteMoeSharedConfig
43
+
44
+
45
class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Optional keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Every key is optional (`total=False`).

    cu_seq_lens_q (`torch.LongTensor`):
        Cumulative sequence lengths for the query states.
    cu_seq_lens_k (`torch.LongTensor`):
        Cumulative sequence lengths for the key states.
    max_length_q (`int`):
        Maximum sequence length for the query states.
    max_length_k (`int`):
        Maximum sequence length for the key states.
    seq_idx (`torch.IntTensor`):
        Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
67
+
68
+
69
class GraniteMoeSharedMLP(nn.Module):
    """
    Gated feed-forward network used as the shared expert.

    A single bias-free projection produces two halves of width `config.shared_intermediate_size`;
    the activated first half multiplies the second half, and the product is projected back to
    `config.hidden_size`.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        # The input projection is twice as wide because it carries both halves of the gated unit.
        self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
        self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform to the last dimension of `hidden_states`."""
        activated_half, linear_half = self.input_linear(hidden_states).chunk(2, dim=-1)
        gated = self.activation(activated_half) * linear_half
        return self.output_linear(gated)
93
+
94
+
95
@use_kernel_forward_from_hub("RMSNorm")
class GraniteMoeSharedRMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel scale."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        GraniteMoeSharedRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size:
                Size of the normalized (last) dimension; the scale is initialized to ones.
            eps (`float`, defaults to 1e-6):
                Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension in float32, then cast back and apply the scale."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # The cast back to the input dtype happens before the scale is applied.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
114
+
115
+
116
class GraniteMoeSharedParallelExperts(nn.Module):
    """Bank of bias-free linear experts that share a single stacked weight tensor."""

    def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
        """
        Initialize the GraniteMoeSharedParallelExperts module.
        The experts weights are stored in [num_experts, output_size, input_size] format. Such that it's compatible with
        many MoE libraries, such as [Megablock](https://github.com/databricks/megablocks) and
        [ScatterMoE](https://github.com/shawntan/scattermoe), as well as the
        [MoE kernel](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)
        used in vllm.

        Args:
            num_experts (int):
                Number of experts.
            input_size (int):
                Size of the input.
            output_size (int):
                Size of the output.
        """
        super().__init__()
        # Allocated uninitialized here; values are expected to be filled in by weight init or loading.
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size

    def forward(self, inputs, expert_size):
        """
        Forward pass of the GraniteMoeSharedParallelExperts module.

        Args:
            inputs (Tensor):
                Input tensor whose rows (dim 0) are grouped by expert, in expert order.
            expert_size:
                Sizes passed to `Tensor.split` along dim 0; entry `i` is the number of rows for expert `i`.

        Returns:
            Tensor: The per-expert outputs concatenated along dim 0, in the same row order as `inputs`.
        """
        per_expert_inputs = inputs.split(expert_size, dim=0)
        per_expert_outputs = [
            F.linear(per_expert_inputs[expert_idx], self.weight[expert_idx])
            for expert_idx in range(self.num_experts)
        ]
        return torch.cat(per_expert_outputs, dim=0)
159
+
160
+
161
class GraniteMoeSharedTopKGating(nn.Module):
    """Router that scores every token against all experts and keeps the `top_k` best."""

    def __init__(self, input_size: int, num_experts: int, top_k: int):
        """
        Initialize the top-k gating mechanism.

        Args:
            input_size (`int`):
                Size of the input.
            num_experts (`int`):
                Number of experts.
            top_k (`int`):
                Number of top experts to select.
        """
        super().__init__()

        self.num_experts = num_experts
        self.input_size = input_size
        self.top_k = top_k

        self.layer = nn.Linear(input_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """
        Route each row of `hidden_states` to its `top_k` experts.

        Returns a 5-tuple `(index_sorted_experts, batch_index, batch_gates, expert_size, logits)`:
        the permutation that sorts the flattened (token, slot) assignments by expert id, the token
        row of every sorted assignment, the matching gate values, the per-expert assignment counts
        as a Python list, and the float router logits.
        """
        # Router scores are computed in float; gates are cast back to the input dtype.
        logits = self.layer(hidden_states).float()  # [num_tokens, num_experts]
        selected_logits, selected_experts = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
        selected_gates = torch.softmax(selected_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]

        # Count how many assignments each expert receives via a 0/1 assignment matrix.
        num_tokens = selected_gates.size(0)
        assignment = torch.zeros(
            [num_tokens, self.num_experts], dtype=selected_gates.dtype, device=selected_gates.device
        ).scatter(1, selected_experts, 1)  # [num_tokens, num_experts]
        # (This cause torch.compile to fail with `torch._dynamo.exc.Unsupported: Backend compiler failed with a fake tensor exception at`)
        # (and `DataDependentOutputException`)
        expert_size = assignment.long().sum(0).tolist()  # length num_experts

        # Sort the flattened assignments by expert id so each expert's inputs are contiguous.
        flat_experts = selected_experts.flatten()  # [num_tokens * top_k]
        index_sorted_experts = flat_experts.sort(0)[1]  # [num_tokens * top_k]
        # Flattened position // top_k recovers the originating token row.
        batch_index = torch.div(index_sorted_experts, self.top_k, rounding_mode="trunc")

        # Reorder the gate values the same way.
        batch_gates = selected_gates.flatten()[index_sorted_experts]  # [num_tokens * top_k]

        return index_sorted_experts, batch_index, batch_gates, expert_size, logits
208
+
209
+
210
class GraniteMoeSharedMoE(nn.Module):
    """
    A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.

    Tokens are routed by a top-k gate, passed through gated per-expert projections, weighted by
    their gate value and summed back into their original token position.

    Args:
        config:
            Configuration object with model hyperparameters.
    """

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        # The input projection is twice as wide: it holds both halves of the gated unit.
        self.input_linear = GraniteMoeSharedParallelExperts(
            config.num_local_experts, self.input_size, self.hidden_size * 2
        )
        self.output_linear = GraniteMoeSharedParallelExperts(
            config.num_local_experts, self.hidden_size, self.input_size
        )

        self.router = GraniteMoeSharedTopKGating(
            input_size=self.input_size,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
        )

    def forward(self, layer_input):
        """Run the routed experts on a 3-D `layer_input` and return a tensor of shape (bsz, length, input_size)."""
        bsz, length, emb_size = layer_input.size()
        flat_tokens = layer_input.reshape(-1, emb_size)
        _, batch_index, batch_gates, expert_size, _ = self.router(flat_tokens)

        # Gather one row per (token, expert) assignment, grouped by expert.
        expert_inputs = flat_tokens[batch_index]
        projected = self.input_linear(expert_inputs, expert_size)
        activated_half, linear_half = projected.chunk(2, dim=-1)
        expert_outputs = self.output_linear(self.activation(activated_half) * linear_half, expert_size)

        # Weight every expert output by its gate value.
        expert_outputs = expert_outputs * batch_gates[:, None]

        # Sum the weighted outputs back into the row of the token they came from.
        accumulator = torch.zeros(
            (bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device
        )
        combined = accumulator.index_add(0, batch_index, expert_outputs)
        return combined.view(bsz, length, self.input_size)
255
+
256
+
257
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    # Negated back half goes first, followed by the untouched front half.
    return torch.cat((-back, front), dim=-1)
262
+
263
+
264
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Insert a singleton axis so cos/sin broadcast against q and k.
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
    k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
    return q_embed, k_embed
288
+
289
+
290
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    When `n_rep == 1` the input tensor itself is returned.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
300
+
301
+
302
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Key/value heads are repeated `module.num_key_value_groups` times, scores are scaled by `scaling`,
    `attention_mask` (when given) is added to the scores, and the softmax is computed in float32
    before being cast back to the query dtype. Dropout follows `module.training`.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has dims 1 and 2 transposed relative to
        the matmul result and is made contiguous.
    """
    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    probabilities = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(probabilities, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, expanded_values).transpose(1, 2).contiguous()
    return attn_output, attn_weights
325
+
326
+
327
@use_kernelized_func(apply_rotary_pos_emb)
class GraniteMoeSharedAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Use an explicit `head_dim` from the config when present, otherwise derive it.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.attention_multiplier  # Only diff with llama
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, key_value_width, bias=config.attention_bias)
        self.o_proj = nn.Linear(query_width, config.hidden_size, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Project to per-head queries/keys/values, apply rotary embeddings, update the cache when one
        is supplied, run the configured attention implementation and project the result back.

        Returns:
            `(attn_output, attn_weights)` as produced by the attention interface, with `attn_output`
            passed through `o_proj`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        def project_heads(projection: nn.Linear) -> torch.Tensor:
            # Split the projected features into heads and move the head axis to position 1.
            return projection(hidden_states).view(per_head_shape).transpose(1, 2)

        query_states = project_heads(self.q_proj)
        key_states = project_heads(self.k_proj)
        value_states = project_heads(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Falls back to the eager implementation when the configured backend is not registered.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        # Dropout is only active in training mode.
        dropout_p = self.attention_dropout if self.training else 0.0
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=dropout_p,
            scaling=self.scaling,
            **kwargs,
        )

        merged_heads = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights
393
+
394
+
395
class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm decoder layer: self-attention, then routed experts plus an optional shared MLP."""

    def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx)
        self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.block_sparse_moe = GraniteMoeSharedMoE(config)
        self.residual_multiplier = config.residual_multiplier  # Only diff with mixtral!
        # The shared MLP is only built when it has a non-zero width.
        if config.shared_intermediate_size == 0:
            self.shared_mlp = None
        else:
            self.shared_mlp = GraniteMoeSharedMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
        """
        Apply the attention and feed-forward sub-blocks, each added to its residual after being
        scaled by `residual_multiplier`, and return the resulting hidden states.
        """
        # --- Self-attention sub-block ---
        attn_input = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output * self.residual_multiplier

        # --- Feed-forward sub-block: routed experts, plus the shared MLP when present ---
        ffn_input = self.post_attention_layernorm(hidden_states)
        ffn_output = self.block_sparse_moe(ffn_input)
        if self.shared_mlp is not None:
            ffn_output = ffn_output + self.shared_mlp(ffn_input)

        return hidden_states + ffn_output * self.residual_multiplier
444
+
445
+
446
@auto_docstring
class GraniteMoeSharedPreTrainedModel(PreTrainedModel):
    # Shared base class: declares framework capability flags and the extra weight-init rule
    # for the stacked expert weights.
    config: GraniteMoeSharedConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GraniteMoeSharedDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backends.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    _can_compile_fullgraph = False  # TopK gating fails fullgraph compilation at "expert_size = expert_size.tolist()"

    # Module classes whose outputs are recorded under each output key.
    _can_record_outputs = {
        "hidden_states": GraniteMoeSharedDecoderLayer,
        "attentions": GraniteMoeSharedAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the parent initialization, then draw stacked expert weights from N(0, initializer_range)."""
        super()._init_weights(module)
        if not isinstance(module, GraniteMoeSharedParallelExperts):
            return
        init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
468
+
469
+
470
class GraniteMoeSharedRotaryEmbedding(nn.Module):
    """Computes the rotary position embedding cos/sin tables for a batch of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: GraniteMoeSharedConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local static method; any other type is looked up in the registry.
        rope_init_fn: Callable
        if self.rope_type == "default":
            rope_init_fn = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Both buffers are non-persistent, i.e. excluded from the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: GraniteMoeSharedConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: 1 / base^(2i / dim) for i in [0, dim / 2)
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`."""
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast is keyed on "cpu" for mps devices and for non-string device types.
        raw_device_type = x.device.type
        if isinstance(raw_device_type, str) and raw_device_type != "mps":
            device_type = raw_device_type
        else:
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
533
+
534
+
535
@auto_docstring
class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel):
    # Decoder stack: token embedding -> decoder layers -> final RMSNorm.
    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [
            GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GraniteMoeSharedRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.embedding_multiplier = config.embedding_multiplier

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one of `input_ids` / `inputs_embeds` must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            # Continue counting from the number of tokens already held by the cache.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            new_positions = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
            position_ids = (new_positions + past_seen_tokens).unsqueeze(0)

        causal_mask = create_causal_mask(  # ONLY DIFF WITH MIXTRAL: NO SLIDING
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # The mask above is built from the unscaled embeddings; scaling is applied afterwards.
        hidden_states = inputs_embeds * self.embedding_multiplier

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
611
+
612
+
613
+ def load_balancing_loss_func(
614
+ gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
615
+ num_experts: int | None = None,
616
+ top_k=2,
617
+ attention_mask: torch.Tensor | None = None,
618
+ ) -> torch.Tensor | int:
619
+ r"""
620
+ Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
621
+
622
+ See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
623
+ function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
624
+ experts is too unbalanced.
625
+
626
+ Args:
627
+ gate_logits:
628
+ Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
629
+ shape [batch_size X sequence_length, num_experts].
630
+ num_experts:
631
+ Number of experts
632
+ top_k:
633
+ The number of experts to route per-token, can be also interpreted as the `top-k` routing
634
+ parameter.
635
+ attention_mask (`torch.Tensor`, *optional*):
636
+ The attention_mask used in forward function
637
+ shape [batch_size X sequence_length] if not None.
638
+
639
+ Returns:
640
+ The auxiliary loss.
641
+ """
642
+ if gate_logits is None or not isinstance(gate_logits, tuple):
643
+ return 0
644
+
645
+ if isinstance(gate_logits, tuple):
646
+ compute_device = gate_logits[0].device
647
+ concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
648
+
649
+ routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
650
+
651
+ _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
652
+
653
+ expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
654
+
655
+ if attention_mask is None:
656
+ # Compute the percentage of tokens routed to each experts
657
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
658
+
659
+ # Compute the average probability of routing to these experts
660
+ router_prob_per_expert = torch.mean(routing_weights, dim=0)
661
+ else:
662
+ batch_size, sequence_length = attention_mask.shape
663
+ num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
664
+
665
+ # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
666
+ expert_attention_mask = (
667
+ attention_mask[None, :, :, None, None]
668
+ .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
669
+ .reshape(-1, top_k, num_experts)
670
+ .to(compute_device)
671
+ )
672
+
673
+ # Compute the percentage of tokens routed to each experts
674
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
675
+ expert_attention_mask, dim=0
676
+ )
677
+
678
+ # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
679
+ router_per_expert_attention_mask = (
680
+ attention_mask[None, :, :, None]
681
+ .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
682
+ .reshape(-1, num_experts)
683
+ .to(compute_device)
684
+ )
685
+
686
+ # Compute the average probability of routing to these experts
687
+ router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
688
+ router_per_expert_attention_mask, dim=0
689
+ )
690
+
691
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
692
+ return overall_loss * num_experts
693
+
694
+
695
@auto_docstring
class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin):
    # Causal language model: GraniteMoeSharedModel followed by a bias-free LM head whose
    # logits are divided by `config.logits_scaling`.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        self.model = GraniteMoeSharedModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.logits_scaling = config.logits_scaling

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        output_router_logits: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GraniteMoeSharedForCausalLM

        >>> model = GraniteMoeSharedForCausalLM.from_pretrained("ibm/PowerMoE-3b")
        >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # An explicit argument wins; otherwise fall back to the config value.
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        # Only compute necessary logits
        hidden_states = outputs.last_hidden_state
        # An int keeps the last `logits_to_keep` positions (0 keeps all); a tensor is used as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        logits = logits / self.config.logits_scaling

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # `load_balancing_loss_func` returns the plain int 0 when no tuple of router logits is
                # available; an int has no `.to()`, so only tensors are moved to the loss device.
                aux_term = aux_loss.to(loss.device) if isinstance(aux_loss, torch.Tensor) else aux_loss
                loss += self.router_aux_loss_coef * aux_term

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
798
+
799
+
800
+ __all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/granitemoeshared/modular_granitemoeshared.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from typing import TypedDict
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ...activations import ACT2FN
21
+ from ...cache_utils import Cache
22
+ from ...processing_utils import Unpack
23
+ from ...utils import logging
24
+ from ..granitemoe.modeling_granitemoe import (
25
+ GraniteMoeDecoderLayer,
26
+ GraniteMoeForCausalLM,
27
+ GraniteMoeModel,
28
+ GraniteMoePreTrainedModel,
29
+ )
30
+ from .configuration_granitemoeshared import GraniteMoeSharedConfig
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
class GraniteFlashAttentionKwargs(TypedDict, total=False):
    """
    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
    Use cases include padding-free training and fewer `torch.compile` graph breaks.

    Every key is optional (`total=False`).

    cu_seq_lens_q (`torch.LongTensor`):
        Gets cumulative sequence length for query state.
    cu_seq_lens_k (`torch.LongTensor`):
        Gets cumulative sequence length for key state.
    max_length_q (`int`):
        Maximum sequence length for query state.
    max_length_k (`int`):
        Maximum sequence length for key state.
    seq_idx (`torch.IntTensor`):
        Index of each packed sequence.
    """

    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int
    seq_idx: torch.IntTensor
58
+
59
+
60
class GraniteMoeSharedMLP(nn.Module):
    """
    Gated feed-forward block used as the shared expert.

    One bias-free projection produces two equally sized halves; the activation is applied to the
    first half, the result multiplies the second half, and a second bias-free projection maps the
    product back to the input width.

    Args:
        config:
            Configuration object; `hidden_size`, `shared_intermediate_size` and `hidden_act` are read.
    """

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__()

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.activation = ACT2FN[config.hidden_act]
        # The first projection is twice as wide because it carries both halves at once.
        self.input_linear = nn.Linear(self.input_size, self.hidden_size * 2, bias=False)
        self.output_linear = nn.Linear(self.hidden_size, self.input_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        halves = self.input_linear(hidden_states).chunk(2, dim=-1)
        gated = self.activation(halves[0]) * halves[1]
        return self.output_linear(gated)
84
+
85
+
86
class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer):
    """
    Decoder layer that optionally adds a shared-expert MLP to the output of the routed experts.

    The shared MLP is only created when `config.shared_intermediate_size` is non-zero; otherwise
    `self.shared_mlp` is `None` and only `self.block_sparse_moe` contributes to the feed-forward path.
    """

    def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # A zero shared width disables the shared expert.
        self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        output_attentions: bool | None = False,
        use_cache: bool | None = False,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs: Unpack[GraniteFlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Run self-attention and the (routed + optional shared) feed-forward path, each with a scaled residual.

        All arguments other than `hidden_states` are forwarded unchanged to `self.self_attn`.

        Returns:
            `torch.Tensor`: the updated hidden states. Only the tensor is returned; the second value
            produced by `self.self_attn` is discarded.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        hidden_states = residual + hidden_states * self.residual_multiplier

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        moe_hidden_states = self.block_sparse_moe(hidden_states)

        # Both expert paths read the same normalized input; their outputs are summed.
        if self.shared_mlp is None:
            hidden_states = moe_hidden_states
        else:
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
        hidden_states = residual + hidden_states * self.residual_multiplier
        return hidden_states
129
+
130
+
131
class GraniteMoeSharedPreTrainedModel(GraniteMoePreTrainedModel):
    # Config class associated with this model family.
    config: GraniteMoeSharedConfig
    # Module class names listed here are presumably kept on a single device when the model is
    # sharded — confirm against the `PreTrainedModel` base class.
    _no_split_modules = ["GraniteMoeSharedDecoderLayer"]
134
+
135
+
136
class GraniteMoeSharedModel(GraniteMoeModel):
    """`GraniteMoeModel` whose `layers` are `GraniteMoeSharedDecoderLayer` instances."""

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        # Assigned after the parent constructor has run, so this is the stack that ends up on `self.layers`.
        shared_layers = [GraniteMoeSharedDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(shared_layers)
142
+
143
+
144
class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM):
    # Maps the LM head weight to the input embedding weight it is tied to.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: GraniteMoeSharedConfig):
        super().__init__(config)
        # Swap in the shared-expert backbone after the parent constructor has run.
        self.model = GraniteMoeSharedModel(config)
        # Initialize weights and apply final processing
        self.post_init()
152
+
153
+
154
# Names exported by `from ... import *` for this module.
__all__ = ["GraniteMoeSharedForCausalLM", "GraniteMoeSharedModel", "GraniteMoeSharedPreTrainedModel"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
if TYPE_CHECKING:
    # Only evaluated by static type checkers: expose the submodules' public names directly.
    from .configuration_instructblip import *
    from .modeling_instructblip import *
    from .processing_instructblip import *
else:
    import sys

    # At runtime, replace this package's entry in `sys.modules` with a `_LazyModule` built from the
    # import structure derived from this file (presumably deferring the heavy submodule imports
    # until first attribute access — confirm against `_LazyModule`).
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/configuration_instructblip.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """InstructBLIP model configuration"""
15
+
16
+ from huggingface_hub.dataclasses import strict
17
+
18
+ from ...configuration_utils import PreTrainedConfig
19
+ from ...models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
20
+ from ...utils import auto_docstring, logging
21
+ from ..auto import CONFIG_MAPPING, AutoConfig
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipVisionConfig(PreTrainedConfig):
    r"""
    Example:

    ```python
    >>> from transformers import InstructBlipVisionConfig, InstructBlipVisionModel

    >>> # Initializing an InstructBlipVisionConfig with Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipVisionConfig()

    >>> # Initializing an InstructBlipVisionModel (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "instructblip_vision_model"
    # Matches the `vision_config` field name used by the composite `InstructBlipConfig`.
    base_config_key = "vision_config"

    # Field defaults below; the declaration order is part of the class interface.
    hidden_size: int = 1408
    intermediate_size: int = 6144
    num_hidden_layers: int = 39
    num_attention_heads: int = 16
    image_size: int | list[int] | tuple[int, int] = 224
    patch_size: int | list[int] | tuple[int, int] = 14
    hidden_act: str = "gelu"
    layer_norm_eps: float = 1e-6
    attention_dropout: float | int = 0.0
    initializer_range: float = 1e-10
    qkv_bias: bool = True
60
+
61
+
62
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipQFormerConfig(PreTrainedConfig):
    r"""
    cross_attention_frequency (`int`, *optional*, defaults to 2):
        The frequency of adding cross-attention to the Transformer layers.
    encoder_hidden_size (`int`, *optional*, defaults to 1408):
        The hidden size of the hidden states for cross-attention.

    Examples:

    ```python
    >>> from transformers import InstructBlipQFormerConfig, InstructBlipQFormerModel

    >>> # Initializing an InstructBLIP Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipQFormerConfig()

    >>> # Initializing a model (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipQFormerModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "instructblip_qformer"
    # Matches the `qformer_config` field name used by the composite `InstructBlipConfig`.
    base_config_key = "qformer_config"

    # Field defaults below; the declaration order is part of the class interface.
    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    hidden_act: str = "gelu"
    hidden_dropout_prob: float | int = 0.1
    attention_probs_dropout_prob: float | int = 0.1
    max_position_embeddings: int = 512
    initializer_range: float = 0.02
    layer_norm_eps: float = 1e-12
    pad_token_id: int | None = 0
    cross_attention_frequency: int = 2
    # Same default as `InstructBlipVisionConfig.hidden_size`; `InstructBlipConfig.__post_init__`
    # overwrites it with the actual vision hidden size.
    encoder_hidden_size: int = 1408
102
+
103
+
104
@auto_docstring(checkpoint="Salesforce/instructblip-flan-t5-xl")
@strict
class InstructBlipConfig(PreTrainedConfig):
    r"""
    qformer_config (`dict`, *optional*):
        Dictionary of configuration options used to initialize [`InstructBlipQFormerConfig`].
    num_query_tokens (`int`, *optional*, defaults to 32):
        The number of query tokens passed through the Transformer.

    Example:

    ```python
    >>> from transformers import (
    ...     InstructBlipVisionConfig,
    ...     InstructBlipQFormerConfig,
    ...     OPTConfig,
    ...     InstructBlipConfig,
    ...     InstructBlipForConditionalGeneration,
    ... )

    >>> # Initializing an InstructBlipConfig with Salesforce/instructblip-flan-t5-xl style configuration
    >>> configuration = InstructBlipConfig()

    >>> # Initializing an InstructBlipForConditionalGeneration (with random weights) from the Salesforce/instructblip-flan-t5-xl style configuration
    >>> model = InstructBlipForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # We can also initialize an InstructBlipConfig from an InstructBlipVisionConfig, InstructBlipQFormerConfig and any PreTrainedConfig

    >>> # Initializing InstructBLIP vision, InstructBLIP Q-Former and language model configurations
    >>> vision_config = InstructBlipVisionConfig()
    >>> qformer_config = InstructBlipQFormerConfig()
    >>> text_config = OPTConfig()

    >>> config = InstructBlipConfig(vision_config=vision_config, qformer_config=qformer_config, text_config=text_config)
    ```"""

    model_type = "instructblip"
    attribute_map = {
        "image_token_id": "image_token_index",
    }
    sub_configs = {
        "text_config": AutoConfig,
        "qformer_config": InstructBlipQFormerConfig,
        "vision_config": InstructBlipVisionConfig,
    }

    vision_config: dict | PreTrainedConfig | None = None
    qformer_config: dict | PreTrainedConfig | None = None
    text_config: dict | PreTrainedConfig | None = None
    num_query_tokens: int = 32
    image_token_index: int | None = None
    initializer_factor: float = 1.0
    initializer_range: float = 0.02

    def __post_init__(self, **kwargs):
        # Each sub-config may arrive as a dict, as `None`, or as a ready-made config object;
        # dicts are instantiated, `None` falls back to defaults, anything else is kept untouched.

        # Language model: a dict is dispatched on its `model_type` (OPT when the key is absent).
        if isinstance(self.text_config, dict):
            text_model_type = self.text_config.get("model_type", "opt")
            self.text_config = CONFIG_MAPPING[text_model_type](**self.text_config)
        elif self.text_config is None:
            self.text_config = CONFIG_MAPPING["opt"]()
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        # Q-Former
        if isinstance(self.qformer_config, dict):
            self.qformer_config = InstructBlipQFormerConfig(**self.qformer_config)
        elif self.qformer_config is None:
            self.qformer_config = InstructBlipQFormerConfig()
            logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.")

        # Vision encoder
        if isinstance(self.vision_config, dict):
            self.vision_config = InstructBlipVisionConfig(**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = InstructBlipVisionConfig()
            logger.info("`vision_config` is `None`. initializing the `InstructBlipVisionConfig` with default values.")

        # The Q-Former's cross-attention width always follows the vision encoder's hidden size.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        # True when the text model type is listed among the causal-LM model types.
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        super().__post_init__(**kwargs)
184
+
185
+
186
# Names exported by `from ... import *` for this module.
__all__ = ["InstructBlipConfig", "InstructBlipQFormerConfig", "InstructBlipVisionConfig"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/modeling_instructblip.py ADDED
@@ -0,0 +1,1405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """PyTorch InstructBLIP model."""
15
+
16
+ import math
17
+ from collections.abc import Callable
18
+ from dataclasses import dataclass
19
+ from typing import Any
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ from ... import initialization as init
25
+ from ...activations import ACT2FN
26
+ from ...generation import GenerationMixin
27
+ from ...masking_utils import create_bidirectional_mask
28
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
29
+ from ...modeling_layers import GradientCheckpointingLayer
30
+ from ...modeling_outputs import (
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ BaseModelOutputWithPooling,
34
+ BaseModelOutputWithPoolingAndCrossAttentions,
35
+ CausalLMOutputWithPast,
36
+ Seq2SeqLMOutput,
37
+ )
38
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from ...processing_utils import Unpack
40
+ from ...pytorch_utils import apply_chunking_to_forward
41
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
42
+ from ...utils.generic import merge_with_config_defaults
43
+ from ...utils.output_capturing import OutputRecorder, capture_outputs
44
+ from ..auto import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
45
+ from .configuration_instructblip import InstructBlipConfig, InstructBlipQFormerConfig, InstructBlipVisionConfig
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
@auto_docstring
@dataclass
class BaseModelOutputWithVisionQformerOutputs(BaseModelOutputWithPooling):
    r"""
    vision_outputs (`BaseModelOutputWithPooling`):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
        Outputs of the Q-Former (Querying Transformer).
    """

    # Both extra fields default to `None`, so they may be omitted at construction time.
    vision_outputs: BaseModelOutputWithPooling | None = None
    qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
63
+
64
+
65
@auto_docstring(
    custom_intro="""
    Class defining the outputs of [`InstructBlipForConditionalGeneration`].
    """
)
@dataclass
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2ForConditionalGenerationModelOutput with Blip2->InstructBlip
class InstructBlipForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Language modeling loss from the language model.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head of the language model.
    vision_outputs (`BaseModelOutputWithPooling`):
        Outputs of the vision encoder.
    qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
        Outputs of the Q-Former (Querying Transformer).
    language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
        Outputs of the language model.
    """

    loss: tuple[torch.FloatTensor] | None = None
    logits: tuple[torch.FloatTensor] | None = None
    vision_outputs: BaseModelOutputWithPooling | None = None
    qformer_outputs: BaseModelOutputWithPoolingAndCrossAttentions | None = None
    language_model_outputs: CausalLMOutputWithPast | Seq2SeqLMOutput | None = None

    def to_tuple(self) -> tuple[Any]:
        # The three nested outputs are flattened through their own `to_tuple()`;
        # every other present key contributes its value as-is.
        nested_keys = ["vision_outputs", "qformer_outputs", "language_model_outputs"]
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
99
+
100
+
101
+ # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->InstructBlip
102
class InstructBlipVisionEmbeddings(nn.Module):
    """
    Turns pixel values into a sequence of patch embeddings with a prepended class embedding and
    added position embeddings.
    """

    def __init__(self, config: InstructBlipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patches: the convolution's stride equals its kernel size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class embedding.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the stored position embeddings so they match the patch grid of a `height` x `width` input,
        which lets the model run on resolutions other than the one it was built for. Tracing with
        `torch.jit` is supported.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        input_patches = embeddings.shape[1] - 1
        stored_patches = self.position_embedding.shape[1] - 1

        # Tracing always takes the interpolation path so the exported graph handles dynamic input shapes.
        must_interpolate = torch.jit.is_tracing() or input_patches != stored_patches or height != width
        if not must_interpolate:
            return self.position_embedding

        class_pos = self.position_embedding[:, :1]
        grid_pos = self.position_embedding[:, 1:]

        dim = embeddings.shape[-1]
        target_size = (height // self.patch_size, width // self.patch_size)

        # The stored patch positions are laid out as a square grid before resizing.
        side = torch_int(stored_patches**0.5)
        grid_pos = grid_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid_pos = nn.functional.interpolate(
            grid_pos,
            size=target_size,
            mode="bicubic",
            align_corners=False,
        )
        grid_pos = grid_pos.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos, grid_pos), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        # Inputs are cast to the dtype of the patch projection weights.
        target_dtype = self.patch_embedding.weight.dtype

        patches = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patches = patches.flatten(2).transpose(1, 2)
        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([cls_tokens, patches], dim=1)

        position_embedding = (
            self.interpolate_pos_encoding(embeddings, height, width)
            if interpolate_pos_encoding
            else self.position_embedding
        )
        # Position embeddings are truncated to the sequence length before being added.
        return embeddings + position_embedding[:, : embeddings.size(1), :].to(target_dtype)
174
+
175
+
176
+ # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBLIP doesn't cast attn weights to fp32
177
+ def eager_attention_forward(
178
+ module: nn.Module,
179
+ query: torch.Tensor,
180
+ key: torch.Tensor,
181
+ value: torch.Tensor,
182
+ attention_mask: torch.Tensor | None,
183
+ scaling: float,
184
+ dropout: float = 0.0,
185
+ **kwargs,
186
+ ):
187
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
188
+ if attention_mask is not None:
189
+ attn_weights = attn_weights + attention_mask
190
+
191
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
192
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
193
+
194
+ attn_output = torch.matmul(attn_weights, value)
195
+ attn_output = attn_output.transpose(1, 2).contiguous()
196
+
197
+ return attn_output, attn_weights
198
+
199
+
200
+ # Copied from transformers.models.blip_2.modeling_blip_2.Blip2Attention with Blip2->InstructBlip
201
+ class InstructBlipAttention(nn.Module):
202
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
203
+
204
+ def __init__(self, config):
205
+ super().__init__()
206
+ self.config = config
207
+ self.embed_dim = config.hidden_size
208
+ self.num_heads = config.num_attention_heads
209
+ self.head_dim = self.embed_dim // self.num_heads
210
+ if self.head_dim * self.num_heads != self.embed_dim:
211
+ raise ValueError(
212
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
213
+ f" {self.num_heads})."
214
+ )
215
+ self.scale = self.head_dim**-0.5
216
+ self.is_causal = False
217
+ self.attention_dropout = config.attention_dropout
218
+
219
+ # small tweak here compared to CLIP, no bias here
220
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
221
+
222
+ if config.qkv_bias:
223
+ q_bias = nn.Parameter(torch.zeros(self.embed_dim))
224
+ v_bias = nn.Parameter(torch.zeros(self.embed_dim))
225
+ else:
226
+ q_bias = None
227
+ v_bias = None
228
+
229
+ if q_bias is not None:
230
+ qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
231
+ self.qkv.bias = nn.Parameter(qkv_bias)
232
+
233
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
234
+
235
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
236
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
237
+
238
+ def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ **kwargs,
242
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
243
+ """Input shape: Batch x Time x Channel"""
244
+
245
+ bsz, tgt_len, embed_dim = hidden_states.size()
246
+
247
+ mixed_qkv = self.qkv(hidden_states)
248
+
249
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
250
+ 2, 0, 3, 1, 4
251
+ )
252
+ query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
253
+
254
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
255
+ self.config._attn_implementation, eager_attention_forward
256
+ )
257
+
258
+ attn_output, attn_weights = attention_interface(
259
+ self,
260
+ query_states,
261
+ key_states,
262
+ value_states,
263
+ attention_mask=None,
264
+ dropout=0.0 if not self.training else self.attention_dropout,
265
+ scaling=self.scale,
266
+ **kwargs,
267
+ )
268
+
269
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
270
+ attn_output = self.projection(attn_output)
271
+
272
+ return attn_output, attn_weights
273
+
274
+
275
+ # Copied from transformers.models.blip.modeling_blip.BlipMLP
276
class InstructBlipMLP(nn.Module):
    """Two-layer feed-forward block: `fc1`, the configured activation, then `fc2`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        # Expand to `intermediate_size`, then project back to `hidden_size`.
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
289
+
290
+
291
+ # Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->InstructBlip
292
class InstructBlipEncoderLayer(GradientCheckpointingLayer):
    # Pre-norm block: LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP -> residual add.

    def __init__(self, config: InstructBlipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = InstructBlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = InstructBlipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block; the second value returned by the attention module is discarded.
        attn_output, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            **kwargs,
        )
        hidden_states = attn_output + hidden_states

        # Feed-forward sub-block.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return mlp_output + hidden_states
322
+
323
+
324
@auto_docstring
class InstructBlipPreTrainedModel(PreTrainedModel):
    config: InstructBlipConfig
    base_model_prefix = "blip"
    input_modalities = ("image", "text")
    supports_gradient_checkpointing = True
    # Capability flags read by the `PreTrainedModel` machinery — exact semantics are defined there.
    _supports_attention_backend = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True

    # Module class names that are presumably kept on a single device when the model is sharded —
    # confirm against `PreTrainedModel`.
    _no_split_modules = [
        "InstructBlipQFormerEmbeddings",
        "InstructBlipAttention",
        "InstructBlipQFormerMultiHeadAttention",
        "InstructBlipQFormerSelfOutput",
    ]

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize the weights"""
        # Generic initialisation first, then the InstructBLIP-specific parameters below.
        super()._init_weights(module)
        factor = self.config.initializer_range
        if isinstance(module, InstructBlipVisionEmbeddings):
            # Learned position and class embeddings: truncated normal with std `initializer_range`.
            init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
        elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)):
            # Query tokens start at zero.
            init.zeros_(module.query_tokens)
        elif isinstance(module, InstructBlipQFormerEmbeddings):
            # Refill `position_ids` with 0..N-1, shaped (1, N).
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
356
+
357
+
358
+ # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip
359
class InstructBlipEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`InstructBlipEncoderLayer`] modules applied one after another.

    Args:
        config (`InstructBlipConfig`):
            The corresponding vision configuration for the `InstructBlipEncoder`.
    """

    def __init__(self, config: InstructBlipConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([InstructBlipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutput:
        # Feed the activations through every layer in order; each layer receives the same kwargs.
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
389
+
390
+
391
class InstructBlipVisionModel(InstructBlipPreTrainedModel):
    """Vision model built from `InstructBlipVisionEmbeddings`, an `InstructBlipEncoder` and a final LayerNorm."""

    main_input_name = "pixel_values"
    input_modalities = ("image",)
    config: InstructBlipVisionConfig
    _can_record_outputs = {
        "hidden_states": InstructBlipEncoderLayer,
        "attentions": InstructBlipAttention,
    }

    def __init__(self, config: InstructBlipVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InstructBlipVisionEmbeddings(config)
        self.encoder = InstructBlipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    @merge_with_config_defaults
    @capture_outputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs: BaseModelOutput = self.encoder(inputs_embeds=embedded, **kwargs)

        last_hidden_state = self.post_layernorm(encoder_outputs.last_hidden_state)

        # The token at index 0 is used as the pooled output. NOTE: it goes through `post_layernorm`
        # a second time here, after already being normalised as part of `last_hidden_state`.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
        )

    def get_input_embeddings(self):
        return self.embeddings
443
+
444
+
445
class InstructBlipQFormerMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention used by the Q-Former.

    When built with `is_cross_attention=True` the key/value projections take inputs of width
    `config.encoder_hidden_size`; otherwise they take `config.hidden_size`. Queries are always
    projected from `config.hidden_size`.

    Setting `save_attention = True` makes cross-attention calls store the attention probabilities
    (see `get_attention_map`) and, when those probabilities take part in autograd, their gradient
    (see `get_attn_gradients`).
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads "
                f"({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        kv_input_size = config.encoder_hidden_size if is_cross_attention else config.hidden_size
        self.key = nn.Linear(kv_input_size, self.all_head_size)
        self.value = nn.Linear(kv_input_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.save_attention = False
        # Defined up front so the getters return None (instead of raising AttributeError) before
        # anything has been recorded.
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        """Store the gradient passed in by the autograd hook registered in `forward`."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the last stored attention gradient, or None if none was recorded."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the attention probabilities of the current call."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the last stored attention probabilities, or None if none were recorded."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dimension into (heads, head_size) and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Compute attention and return `(context_layer, attention_probs)`.

        If `encoder_hidden_states` is given, keys and values are projected from it and
        `encoder_attention_mask` replaces `attention_mask`; otherwise keys and values come from
        `hidden_states`. The mask, when present, is added to the raw scores before the softmax.
        The returned probabilities are the ones before dropout.
        """
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            kv_source = encoder_hidden_states
            attention_mask = encoder_attention_mask
        else:
            kv_source = hidden_states

        key_layer = self.transpose_for_scores(self.key(kv_source))
        value_layer = self.transpose_for_scores(self.value(kv_source))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Raw scores: dot product between queries and keys, scaled by sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores_dtype = attention_scores.dtype

        if attention_mask is not None:
            # Additive mask; adding it may change the dtype, hence the cast after the softmax.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1).to(attention_scores_dtype)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            # `register_hook` raises on tensors that do not require grad (e.g. under
            # `torch.no_grad()`), so only record gradients when autograd tracks the tensor.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Move the head axis back next to head_size and merge the two into `all_head_size`.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_probs
540
+
541
+
542
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->InstructBlipQFormer
543
class InstructBlipQFormerSelfOutput(nn.Module):
    """Output stage after attention: linear projection, dropout, then LayerNorm over the residual sum."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, add `input_tensor` as the residual, and normalise the sum."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
555
+
556
+
557
+ # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerAttention with Blip2->InstructBlip
558
class InstructBlipQFormerAttention(nn.Module):
    """Pairs `InstructBlipQFormerMultiHeadAttention` with its `InstructBlipQFormerSelfOutput` stage."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = InstructBlipQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = InstructBlipQFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run attention, then the output stage with `hidden_states` as the residual input.

        The attention probabilities returned by the inner module are discarded.
        """
        context, _ = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )
        return self.output(context, hidden_states)
581
+
582
+
583
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->InstructBlipQFormer
584
class InstructBlipQFormerIntermediate(nn.Module):
    """Expanding half of the feed-forward block: linear layer followed by the activation.

    `config.hidden_act` may be a string, which is looked up in `ACT2FN`, or any other object,
    which is used directly as the activation callable.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
597
+
598
+
599
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->InstructBlipQFormer
600
class InstructBlipQFormerOutput(nn.Module):
    """Contracting half of the feed-forward block: linear projection, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` back to `hidden_size`, add `input_tensor`, and normalise the sum."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
612
+
613
+
614
class InstructBlipQFormerLayer(GradientCheckpointingLayer):
    """One Q-Former layer.

    Every layer has a self-attention module. Layers whose `layer_idx` is a multiple of
    `config.cross_attention_frequency` also get a cross-attention module. Two feed-forward paths
    exist: `intermediate_query`/`output_query` for the first `query_length` positions and
    `intermediate`/`output` for the remaining positions.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = InstructBlipQFormerAttention(config)

        self.layer_idx = layer_idx

        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = InstructBlipQFormerAttention(config, is_cross_attention=True)

        self.intermediate = InstructBlipQFormerIntermediate(config)
        self.output = InstructBlipQFormerOutput(config)

        self.intermediate_query = InstructBlipQFormerIntermediate(config)
        self.output_query = InstructBlipQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        query_length=0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Apply self-attention, optional cross-attention on the query slice, and the feed-forward paths.

        Raises:
            ValueError: if this layer has cross-attention, `query_length > 0`, and
                `encoder_hidden_states` is None.
        """
        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        # No query positions: the whole sequence goes through the regular feed-forward path.
        if query_length <= 0:
            return apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        query_states = attention_output[:, :query_length, :]

        if self.has_cross_attention:
            if encoder_hidden_states is None:
                raise ValueError("encoder_hidden_states must be given for cross-attention layers")
            query_states = self.crossattention(
                query_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk_query,
            self.chunk_size_feed_forward,
            self.seq_len_dim,
            query_states,
        )

        # Positions after the query slice use the other feed-forward path and are appended.
        if attention_output.shape[1] > query_length:
            remainder_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output[:, query_length:, :],
            ).to(layer_output.device)
            layer_output = torch.cat([layer_output, remainder_output], dim=1)

        return layer_output

    def feed_forward_chunk(self, attention_output):
        """Feed-forward path built from `intermediate` and `output`."""
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward path built from `intermediate_query` and `output_query`."""
        return self.output_query(self.intermediate_query(attention_output), attention_output)
697
+
698
+
699
+ # Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerEncoder with Blip2->InstructBlip
700
class InstructBlipQFormerEncoder(nn.Module):
    """Sequence of `config.num_hidden_layers` [`InstructBlipQFormerLayer`] modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [InstructBlipQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        query_length=0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Run `hidden_states` through the first `config.num_hidden_layers` layers, in index order."""
        for layer_idx in range(self.config.num_hidden_layers):
            hidden_states = self.layer[layer_idx](
                hidden_states,
                attention_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_attention_mask,
                query_length=query_length,
                **kwargs,
            )

        return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states)
734
+
735
+
736
class InstructBlipQFormerEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    Token ids are embedded and summed with position embeddings. If `query_embeds` is given it is
    placed in front of the token embeddings along dimension 1. With no token ids, `query_embeds`
    alone is used. The result goes through LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids has shape (1, max_position_embeddings). It is registered as a non-persistent
        # buffer, so it is not part of the saved state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Build the embedding sequence.

        Args:
            input_ids: token ids to embed, or None to use `query_embeds` only.
            position_ids: explicit position indices for `input_ids`; if None, a slice of the
                `position_ids` buffer starting at `past_key_values_length` is used.
            query_embeds: embeddings placed in front of the token embeddings, or None.
            past_key_values_length: offset into the `position_ids` buffer.

        Raises:
            ValueError: if both `input_ids` and `query_embeds` are None.
        """
        if input_ids is None and query_embeds is None:
            # Without this guard the code below would fail on `None.to(...)` with an opaque error.
            raise ValueError("You have to specify at least one of `input_ids` or `query_embeds`")

        if input_ids is None:
            embeddings = query_embeds
        else:
            if position_ids is None:
                seq_length = input_ids.size()[1]
                position_ids = self.position_ids[
                    :, past_key_values_length : seq_length + past_key_values_length
                ].clone()

            embeddings = self.word_embeddings(input_ids)
            position_embeddings = self.position_embeddings(position_ids.to(embeddings.device))
            embeddings = embeddings + position_embeddings

            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        # Match the LayerNorm parameter dtype before normalising.
        embeddings = embeddings.to(self.layernorm.weight.dtype)
        embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
784
+
785
+
786
class InstructBlipQFormerModel(InstructBlipPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in InstructBLIP. Slightly modified from BLIP-2 as it also takes the
    instruction as input.
    """

    _supports_attention_backend = False  # adds position on attn weights before last matmul
    _supports_flash_attn = False
    _supports_sdpa = False
    _supports_flex_attn = False

    _can_record_outputs = {
        "hidden_states": InstructBlipQFormerLayer,
        "attentions": [
            OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".attention"),
        ],
        "cross_attentions": [
            OutputRecorder(InstructBlipQFormerMultiHeadAttention, index=1, layer_name=".crossattention"),
        ],
    }

    def __init__(self, config: InstructBlipQFormerConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InstructBlipQFormerEmbeddings(config)
        self.encoder = InstructBlipQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        query_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.FloatTensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        r"""
        query_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Hidden states to be used in the attention computation. If cross-attention,
            will be used for the query (i.e., key and value will use the encoder_hidden_states).
        """
        if input_ids is None and query_embeds is None:
            raise ValueError("You have to specify query_embeds when input_ids is None")

        # Number of leading positions that come from `query_embeds` (0 when none are given).
        query_length = 0 if query_embeds is None else query_embeds.shape[1]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            query_embeds=query_embeds,
        )

        # Mask over the embedded sequence itself.
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=embedding_output,
            attention_mask=attention_mask,
        )

        # Mask over the encoder states; only built when the caller supplied one.
        if encoder_attention_mask is not None:
            encoder_attention_mask = create_bidirectional_mask(
                config=self.config,
                inputs_embeds=embedding_output,
                attention_mask=encoder_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
            )

        encoder_outputs: BaseModelOutput = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            query_length=query_length,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state

        # The hidden state at position 0 is reported as the pooled output.
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=sequence_output[:, 0, :],
        )
881
+
882
+
883
@auto_docstring(
    custom_intro="""
    InstructBLIP base Model consisting of language model, qformer and vision encoder.
    """
)
class InstructBlipModel(InstructBlipPreTrainedModel):
    main_input_name = "pixel_values"
    _keep_in_fp32_modules = ["query_tokens"]  # TODO @ArthurZucker I don't know why this is required for FP8

    def __init__(self, config: InstructBlipConfig):
        super().__init__(config)

        self.vision_model = InstructBlipVisionModel(config.vision_config)
        # Learned query tokens of shape (1, num_query_tokens, qformer hidden size); expanded over the
        # batch in `forward`.
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipQFormerModel(config.qformer_config)

        # Maps Q-Former outputs to the hidden size of the language model.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModel.from_config(config.text_config)

        # Initialize weights and apply final processing
        self.post_init()

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.
        """
        if input_ids is None:
            # No ids available: a position counts as a placeholder when its whole embedding vector
            # equals the embedding of `config.image_token_id`.
            image_token_embedding = self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = (inputs_embeds == image_token_embedding).all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        # Broadcast the per-position mask over the embedding dimension.
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        return special_image_mask

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: torch.LongTensor | None = None,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.LongTensor | None = None,
        decoder_input_ids: torch.LongTensor | None = None,
        decoder_attention_mask: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple | InstructBlipForConditionalGenerationModelOutput:
        r"""
        qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
            to serve as text prompt, which the Q-Former model will encode.

            Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
            details.

            [What are input IDs?](../glossary#input-ids)
        qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            Only relevant in case an encoder-decoder language model (like T5) is used.
        """

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)
        # Query positions come first, so their all-ones mask is prepended to the instruction mask.
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **kwargs,
        )
        # Keep only the positions that correspond to the query tokens.
        query_output = query_outputs[0][:, : query_tokens.size(1), :]

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                # Only `inputs_embeds` was supplied: `torch.ones_like(None)` would fail, so build the
                # all-ones mask from the embeddings' leading dimensions instead.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device
                )

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        # Write the projected query outputs into the image-placeholder positions of the prompt.
        special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                **kwargs,
            )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                **kwargs,
            )

        return InstructBlipForConditionalGenerationModelOutput(
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
1037
+
1038
+
1039
+ @auto_docstring(
1040
+ custom_intro="""
1041
+ InstructBLIP Model for generating text given an image and an optional text prompt. The model consists of a vision
1042
+ encoder, Querying Transformer (Q-Former) and a language model.
1043
+
1044
+ One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
1045
+ the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
1046
+ """
1047
+ )
1048
+ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
1049
+ config: InstructBlipConfig
1050
+ main_input_name = "pixel_values"
1051
+
1052
+ _can_compile_fullgraph = True
1053
+ _keep_in_fp32_modules = ["query_tokens"] # TODO @ArthurZucker I don't know why this is required for FP8
1054
+
1055
+ def __init__(self, config: InstructBlipConfig):
1056
+ super().__init__(config)
1057
+
1058
+ self.vision_model = InstructBlipVisionModel._from_config(config.vision_config)
1059
+
1060
+ self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
1061
+ self.qformer = InstructBlipQFormerModel._from_config(config.qformer_config)
1062
+
1063
+ self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
1064
+
1065
+ if config.use_decoder_only_language_model:
1066
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
1067
+ else:
1068
+ language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
1069
+
1070
+ self.language_model = language_model
1071
+
1072
+ # Initialize weights and apply final processing
1073
+ self.post_init()
1074
+
1075
+ def set_output_embeddings(self, new_embeddings):
1076
+ self.language_model.set_output_embeddings(new_embeddings)
1077
+
1078
+ def get_output_embeddings(self) -> nn.Module:
1079
+ return self.language_model.get_output_embeddings()
1080
+
1081
+ def get_encoder(self, modality=None):
1082
+ if modality is None:
1083
+ return self.language_model.get_encoder()
1084
+ else:
1085
+ return super().get_encoder(modality=modality)
1086
+
1087
+ def get_decoder(self):
1088
+ return self.language_model.get_decoder()
1089
+
1090
+ # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate
1091
+ def _preprocess_accelerate(self):
1092
+ r"""
1093
+ Some pre-processing hacks to make the model `accelerate` compatible. Check
1094
+ https://github.com/huggingface/transformers/pull/21707 for more details.
1095
+ """
1096
+ hf_device_map = self.hf_device_map
1097
+
1098
+ if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
1099
+ # warn users about unexpected behavior when using multi-GPU + InstructBLIP + `accelerate`.
1100
+ logger.warning(
1101
+ "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
1102
+ " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
1103
+ " Please pass a `device_map` that contains `language_model` to remove this warning."
1104
+ " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
1105
+ " more details on creating a `device_map` for large models.",
1106
+ )
1107
+
1108
+ if hasattr(self.language_model, "_hf_hook"):
1109
+ self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
1110
+
1111
@can_return_tuple
@auto_docstring
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: torch.LongTensor,
    qformer_attention_mask: torch.LongTensor | None = None,
    interpolate_pos_encoding: bool | None = False,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithVisionQformerOutputs:
    r"""
    pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        The tensors corresponding to the input images.
    qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
        Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
        to serve as text prompt, which the Q-Former model will encode.

        Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
        details.

        [What are input IDs?](../glossary#input-ids)
    qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

        [What are attention masks?](../glossary#attention-mask)
    """
    # step 1: run the vision encoder on the images; its first output holds the image embeddings
    encoder_outputs: BaseModelOutputWithPooling = self.vision_model(
        pixel_values=pixel_values,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=True,
        **kwargs,
    )
    # Re-wrap the encoder fields and keep the raw encoder output reachable as `vision_outputs`.
    result = BaseModelOutputWithVisionQformerOutputs(**encoder_outputs, vision_outputs=encoder_outputs)
    image_embeds = result[0]
    embeds_device = image_embeds.device

    # step 2: run the Q-Former over the query tokens, cross-attending to every image position
    image_attention_mask = torch.ones(image_embeds.shape[:-1], dtype=torch.long, device=embeds_device)

    # difference with BLIP-2 here: the instruction prompt is fed to the Q-Former as well, so the
    # attention mask is the all-ones query mask followed by the prompt mask
    query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
    num_queries = query_tokens.size(1)
    query_attention_mask = torch.ones(query_tokens.shape[:-1], dtype=torch.long, device=embeds_device)
    prompt_mask = torch.ones_like(qformer_input_ids) if qformer_attention_mask is None else qformer_attention_mask
    qformer_attention_mask = torch.cat([query_attention_mask, prompt_mask], dim=1)

    qformer_outputs = self.qformer(
        input_ids=qformer_input_ids,
        attention_mask=qformer_attention_mask,
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        return_dict=True,
        **kwargs,
    )
    result.qformer_outputs = qformer_outputs

    # Keep only the leading query positions of the Q-Former output; the prompt positions are dropped.
    query_output = qformer_outputs[0][:, :num_queries, :]

    # step 3: pass the query outputs through the language projection and expose them as `pooler_output`
    result.pooler_output = self.language_projection(query_output)

    return result
1177
+
1178
def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor):
    """
    Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`.

    When `input_ids` is available, positions equal to `config.image_token_id` are marked. Otherwise
    positions whose embedding matches the image token's embedding in every component are marked.
    The result is broadcast to the shape of `inputs_embeds`.
    """
    image_token_id = self.config.image_token_id

    if input_ids is not None:
        token_mask = input_ids == image_token_id
    else:
        # No ids to compare against: look for embeddings identical to the image token's embedding.
        image_token = torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
        image_token_embed = self.get_input_embeddings()(image_token)
        token_mask = (inputs_embeds == image_token_embed).all(-1)

    expanded_mask = token_mask.unsqueeze(-1).expand_as(inputs_embeds)
    return expanded_mask.to(inputs_embeds.device)
1192
+
1193
@can_return_tuple
@auto_docstring
def forward(
    self,
    pixel_values: torch.FloatTensor,
    qformer_input_ids: torch.LongTensor,
    qformer_attention_mask: torch.LongTensor | None = None,
    input_ids: torch.LongTensor | None = None,
    attention_mask: torch.LongTensor | None = None,
    decoder_input_ids: torch.LongTensor | None = None,
    decoder_attention_mask: torch.LongTensor | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    labels: torch.LongTensor | None = None,
    interpolate_pos_encoding: bool = False,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | InstructBlipForConditionalGenerationModelOutput:
    r"""
    qformer_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Indices of input sequence tokens in the vocabulary of the Q-Former. Input tokens can optionally be provided
        to serve as text prompt, which the Q-Former model will encode.

        Indices can be obtained using [`InstructBlipProcessor`]. See [`InstructBlipProcessor.__call__`] for
        details.

        [What are input IDs?](../glossary#input-ids)
    qformer_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.

        [What are attention masks?](../glossary#attention-mask)
    decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
        Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
        be used by default.

        Only relevant in case an encoder-decoder language model (like T5) is used.
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
        1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
        config.vocab_size]`

    Examples:

    ```python
    >>> from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
    >>> import torch
    >>> from PIL import Image
    >>> import httpx
    >>> from io import BytesIO

    >>> model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b")
    >>> processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

    >>> device = "cuda" if torch.cuda.is_available() else "cpu"
    >>> model.to(device)  # doctest: +IGNORE_RESULT

    >>> url = "https://raw.githubusercontent.com/salesforce/LAVIS/main/docs/_static/Confusing-Pictures.jpg"
    >>> with httpx.stream("GET", url) as response:
    ...     image = Image.open(BytesIO(response.read())).convert("RGB")
    >>> prompt = "What is unusual about this image?"
    >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    >>> outputs = model.generate(
    ...     **inputs,
    ...     do_sample=False,
    ...     num_beams=5,
    ...     max_length=256,
    ...     min_length=1,
    ...     top_p=0.9,
    ...     repetition_penalty=1.5,
    ...     length_penalty=1.0,
    ...     temperature=1,
    ... )
    >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    >>> print(generated_text)
    The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV, which is parked in the middle of a busy city street. This is an unconventional approach to ironing clothes, as it requires the man to balance himself and his ironing equipment on top of the vehicle while navigating through traffic. Additionally, the presence of taxis and other vehicles in the scene further emphasizes the unusual nature of this situation.
    ```"""

    # Vision encoder -> Q-Former -> language projection; `pooler_output` holds the projected features.
    image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
        pixel_values,
        qformer_input_ids=qformer_input_ids,
        qformer_attention_mask=qformer_attention_mask,
        interpolate_pos_encoding=interpolate_pos_encoding,
        return_dict=True,
    )
    language_model_inputs = image_features.pooler_output
    qformer_outputs = image_features.qformer_outputs
    vision_outputs = image_features.vision_outputs

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if attention_mask is None:
        if input_ids is not None:
            attention_mask = torch.ones_like(input_ids)
        else:
            # Only `inputs_embeds` was supplied, so there are no ids to copy a shape from:
            # attend to every embedded position instead of failing on `ones_like(None)`.
            attention_mask = torch.ones(
                inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device
            )

    # Overwrite the image placeholder positions of the text embeddings with the projected image features.
    language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
    special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

    if self.config.use_decoder_only_language_model:
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )
        logits = outputs[0]
        loss = None
        if labels is not None:
            # The decoder-only path computes the loss here rather than inside the language model.
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
            )

    else:
        # The attribute access below (`.loss`, `.logits`) needs a dict-style output.
        kwargs["return_dict"] = True
        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            **kwargs,
        )
        loss = outputs.loss
        logits = outputs.logits

    return InstructBlipForConditionalGenerationModelOutput(
        loss=loss,
        logits=logits,
        vision_outputs=vision_outputs,
        qformer_outputs=qformer_outputs,
        language_model_outputs=outputs,
    )
1326
+
1327
+ @torch.no_grad()
1328
+ def generate(
1329
+ self,
1330
+ pixel_values: torch.FloatTensor,
1331
+ qformer_input_ids: torch.LongTensor | None = None,
1332
+ qformer_attention_mask: torch.LongTensor | None = None,
1333
+ input_ids: torch.LongTensor | None = None,
1334
+ attention_mask: torch.LongTensor | None = None,
1335
+ inputs_embeds: torch.FloatTensor | None = None,
1336
+ interpolate_pos_encoding: bool = False,
1337
+ **generate_kwargs,
1338
+ ) -> torch.LongTensor:
1339
+ """
1340
+ Overrides `generate` function to be able to use the model as a conditional generator.
1341
+
1342
+ Args:
1343
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
1344
+ Input images to be processed.
1345
+ qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1346
+ The sequence used as a prompt to be fed to the Q-Former module.
1347
+ qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1348
+ Mask to avoid performing attention on padding token indices.
1349
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1350
+ The sequence used as a prompt for the generation.
1351
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1352
+ Mask to avoid performing attention on padding token indices.
1353
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
1354
+ Embedded representation of the inputs. Should be float, not int tokens.
1355
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
1356
+ Whether to interpolate the positional encoding of the image embeddings.
1357
+
1358
+ Returns:
1359
+ captions (list): A list of strings of length batch_size * num_captions.
1360
+ """
1361
+ if hasattr(self, "hf_device_map"):
1362
+ # preprocess for `accelerate`
1363
+ self._preprocess_accelerate()
1364
+
1365
+ batch_size = pixel_values.shape[0]
1366
+ image_features: BaseModelOutputWithVisionQformerOutputs = self.get_image_features(
1367
+ pixel_values,
1368
+ qformer_input_ids=qformer_input_ids,
1369
+ qformer_attention_mask=qformer_attention_mask,
1370
+ interpolate_pos_encoding=interpolate_pos_encoding,
1371
+ return_dict=True,
1372
+ )
1373
+ language_model_inputs = image_features.pooler_output
1374
+
1375
+ if inputs_embeds is None:
1376
+ if input_ids is None:
1377
+ image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
1378
+ start_tokens = image_tokens + [self.config.text_config.bos_token_id]
1379
+ input_ids = torch.tensor([start_tokens], dtype=torch.long, device=pixel_values.device)
1380
+ input_ids = input_ids.repeat(batch_size, 1)
1381
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1382
+
1383
+ if attention_mask is None:
1384
+ attention_mask = torch.ones_like(input_ids)
1385
+
1386
+ language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
1387
+ special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
1388
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
1389
+
1390
+ inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
1391
+ if not self.language_model.config.is_encoder_decoder:
1392
+ inputs["input_ids"] = input_ids
1393
+
1394
+ outputs = self.language_model.generate(**inputs, **generate_kwargs)
1395
+
1396
+ return outputs
1397
+
1398
+
1399
+ __all__ = [
1400
+ "InstructBlipQFormerModel",
1401
+ "InstructBlipPreTrainedModel",
1402
+ "InstructBlipModel",
1403
+ "InstructBlipForConditionalGeneration",
1404
+ "InstructBlipVisionModel",
1405
+ ]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/instructblip/processing_instructblip.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Processor class for InstructBLIP. Largely a copy of Blip2Processor with the addition of a tokenizer for the Q-Former.
16
+ """
17
+
18
+ from ...image_processing_utils import BatchFeature
19
+ from ...image_utils import ImageInput
20
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
21
+ from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
22
+ from ...utils import auto_docstring, logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
class InstructBlipProcessorKwargs(ProcessingKwargs, total=False):
    # Default tokenizer options for the processor. `InstructBlipProcessor.__call__` hands this
    # class to `_merge_kwargs` together with the caller's keyword arguments.
    _defaults = {
        "text_kwargs": dict(
            add_special_tokens=True,
            padding=False,
            stride=0,
            return_overflowing_tokens=False,
            return_special_tokens_mask=False,
            return_offsets_mapping=False,
            return_token_type_ids=False,
            return_length=False,
            verbose=True,
        ),
    }
42
+
43
+
44
# Bundles an image processor, the language-model tokenizer and a separate Q-Former tokenizer, and
# produces the `qformer_*` inputs, the language-model text inputs and the image inputs in one call.
@auto_docstring
class InstructBlipProcessor(ProcessorMixin):
    def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
        r"""
        qformer_tokenizer (`AutoTokenizer`):
            An instance of [`PreTrainedTokenizer`]. The Q-Former tokenizer is a required input.
        num_query_tokens (`int`, *optional*):
            Number of tokens used by the Qformer as queries, should be same as in model's config.
        """
        if not hasattr(tokenizer, "image_token"):
            # The tokenizer has no image placeholder of its own: register `<image>` as a special token.
            self.image_token = AddedToken("<image>", normalized=False, special=True)
            tokenizer.add_tokens([self.image_token], special_tokens=True)
        else:
            self.image_token = tokenizer.image_token
        self.num_query_tokens = num_query_tokens

        super().__init__(image_processor, tokenizer, qformer_tokenizer)

    @auto_docstring
    def __call__(
        self,
        images: ImageInput | None = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        **kwargs: Unpack[InstructBlipProcessorKwargs],
    ) -> BatchFeature:
        if images is None and text is None:
            raise ValueError("You have to specify at least images or text.")

        output_kwargs = self._merge_kwargs(
            InstructBlipProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # Tokenize to plain Python lists first; the tensor conversion happens once, at the very end.
        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        encoding = {}
        if text is not None:
            if isinstance(text, str):
                text = [text]
            # NOTE(review): this only raises for a non-list input whose first element is not a string;
            # lists are never rejected here. Confirm whether a stricter check was intended.
            elif not isinstance(text, list) and not isinstance(text[0], str):
                raise ValueError("Invalid input text. Please provide a string, or a list of strings")

            qformer_text_encoding = self.qformer_tokenizer(text, **output_kwargs["text_kwargs"])
            encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

            # We need this hacky manipulation because BLIP expects image tokens to be at the beginning even before BOS token
            if output_kwargs["text_kwargs"].get("max_length") is not None:
                # Leave room in the length budget for the image tokens prepended below.
                output_kwargs["text_kwargs"]["max_length"] -= self.num_query_tokens
            text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

            if images is not None:
                # `self.image_token` is an `AddedToken` when it was registered in `__init__`, but it is
                # whatever the tokenizer exposes as `image_token` otherwise, which may be a plain string
                # without a `.content` attribute.
                image_token_text = getattr(self.image_token, "content", self.image_token)
                # Image tokens should not be padded/truncated or prepended with special BOS token
                image_tokens = image_token_text * self.num_query_tokens
                output_kwargs["text_kwargs"]["add_special_tokens"] = False
                output_kwargs["text_kwargs"]["padding"] = False
                output_kwargs["text_kwargs"]["truncation"] = False
                image_text_encoding = self.tokenizer(image_tokens, **output_kwargs["text_kwargs"])
                # Prepend the image-token encoding to every sample of every tokenizer output field.
                for k in text_encoding:
                    text_encoding[k] = [image_text_encoding[k] + sample for sample in text_encoding[k]]
            encoding.update(text_encoding)

        if images is not None:
            image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
            encoding.update(image_encoding)

        # Cast to desired return tensors type
        encoding = BatchFeature(encoding, tensor_type=return_tensors)
        return encoding

    @property
    def model_input_names(self):
        """Input names of the tokenizer and the image processor, followed by the two Q-Former input names."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        qformer_input_names = ["qformer_input_ids", "qformer_attention_mask"]
        return tokenizer_input_names + image_processor_input_names + qformer_input_names
121
+
122
+
123
+ __all__ = ["InstructBlipProcessor"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mllama/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from ...utils import _LazyModule
17
+ from ...utils.import_utils import define_import_structure
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from .configuration_mllama import *
22
+ from .image_processing_mllama import *
23
+ from .image_processing_pil_mllama import *
24
+ from .modeling_mllama import *
25
+ from .processing_mllama import *
26
+ else:
27
+ import sys
28
+
29
+ _file = globals()["__file__"]
30
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/mobilevit/modeling_mobilevit.py ADDED
@@ -0,0 +1,963 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
16
+ """PyTorch MobileViT model."""
17
+
18
+ import math
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+
24
+ from ... import initialization as init
25
+ from ...activations import ACT2FN
26
+ from ...modeling_layers import GradientCheckpointingLayer
27
+ from ...modeling_outputs import (
28
+ BaseModelOutputWithNoAttention,
29
+ BaseModelOutputWithPoolingAndNoAttention,
30
+ ImageClassifierOutputWithNoAttention,
31
+ SemanticSegmenterOutput,
32
+ )
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import auto_docstring, logging, torch_int
35
+ from .configuration_mobilevit import MobileViTConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+
41
+ def make_divisible(value: int, divisor: int = 8, min_value: int | None = None) -> int:
42
+ """
43
+ Ensure that all layers have a channel count that is divisible by `divisor`.
44
+ """
45
+ if min_value is None:
46
+ min_value = divisor
47
+ new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
48
+ # Make sure that round down does not go down by more than 10%.
49
+ if new_value < 0.9 * value:
50
+ new_value += divisor
51
+ return int(new_value)
52
+
53
+
54
class MobileViTConvLayer(nn.Module):
    """A 2D convolution optionally followed by batch normalization and an activation function."""

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: bool | str = True,
    ) -> None:
        super().__init__()
        conv_padding = int((kernel_size - 1) / 2) * dilation

        # Grouped convolutions need both channel counts to split evenly across the groups.
        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        self.normalization = (
            nn.BatchNorm2d(
                num_features=out_channels,
                eps=1e-5,
                momentum=0.1,
                affine=True,
                track_running_stats=True,
            )
            if use_normalization
            else None
        )

        # `use_activation` may be falsy (no activation), a string naming an activation, or simply
        # truthy, in which case the activation configured in `config.hidden_act` is used.
        if not use_activation:
            self.activation = None
        elif isinstance(use_activation, str):
            self.activation = ACT2FN[use_activation]
        elif isinstance(config.hidden_act, str):
            self.activation = ACT2FN[config.hidden_act]
        else:
            self.activation = config.hidden_act

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, then the normalization and activation stages that are enabled."""
        features = self.convolution(features)
        for stage in (self.normalization, self.activation):
            if stage is not None:
                features = stage(features)
        return features
116
+
117
+
118
class MobileViTInvertedResidual(nn.Module):
    """
    Inverted residual block (MobileNetv2): https://huggingface.co/papers/1801.04381

    A 1x1 expansion, a 3x3 convolution with one group per expanded channel, and a 1x1 reduction
    without activation. The input is added back only when the stride is 1 and the channel count
    is unchanged.
    """

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()
        hidden_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        # The skip connection is only shape-compatible when nothing is downsampled or re-projected.
        self.use_residual = (stride == 1) and (in_channels == out_channels)

        self.expand_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
        )
        self.conv_3x3 = MobileViTConvLayer(
            config,
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=stride,
            groups=hidden_channels,
            dilation=dilation,
        )
        self.reduce_1x1 = MobileViTConvLayer(
            config,
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run expand -> 3x3 -> reduce and add the input back when the residual path is enabled."""
        transformed = self.reduce_1x1(self.conv_3x3(self.expand_1x1(features)))
        if self.use_residual:
            return features + transformed
        return transformed
164
+
165
+
166
class MobileViTMobileNetLayer(nn.Module):
    """A stack of `num_stages` inverted residual blocks applied one after another."""

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
    ) -> None:
        super().__init__()

        # Only the first block changes the channel count and applies the requested stride; the
        # remaining blocks map `out_channels` to `out_channels` with stride 1.
        blocks = [
            MobileViTInvertedResidual(
                config,
                in_channels=in_channels if index == 0 else out_channels,
                out_channels=out_channels,
                stride=stride if index == 0 else 1,
            )
            for index in range(num_stages)
        ]
        self.layer = nn.ModuleList(blocks)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Feed the features through every block in order."""
        output = features
        for block in self.layer:
            output = block(output)
        return output
187
+
188
+
189
class MobileViTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over the input hidden states."""

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()

        num_heads = config.num_attention_heads
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"The hidden size {hidden_size} is not a multiple of the number of attention "
                f"heads {num_heads}."
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(hidden_size / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to queries/keys/values, attend per head, and merge the heads back together."""
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.attention_head_size)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # Unfold the last dimension into (heads, head_size) and swap dimensions 1 and 2.
            return projected.view(per_head_shape).transpose(1, 2)

        queries = split_heads(self.query(hidden_states))
        keys = split_heads(self.key(hidden_states))
        values = split_heads(self.value(hidden_states))

        # Raw attention scores: query/key dot products, scaled by sqrt(head_size).
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        probs = nn.functional.softmax(scores, dim=-1)

        # Dropout on the probabilities drops entire tokens to attend to, which might seem a bit
        # unusual, but is taken from the original Transformer paper.
        probs = self.dropout(probs)

        context = torch.matmul(probs, values)

        # Swap dimensions 1 and 2 back and flatten (heads, head_size) into `all_head_size`.
        context = context.permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return context.view(*merged_shape)
233
+
234
+
235
class MobileViTSelfOutput(nn.Module):
    """Linear projection followed by dropout, applied to the self-attention output."""

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project the hidden states and apply dropout."""
        projected = self.dense(hidden_states)
        return self.dropout(projected)
245
+
246
+
247
class MobileViTAttention(nn.Module):
    """Self-attention followed by its output projection."""

    def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
        super().__init__()
        self.attention = MobileViTSelfAttention(config, hidden_size)
        self.output = MobileViTSelfOutput(config, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend over the hidden states, then project the attended result."""
        return self.output(self.attention(hidden_states))
257
+
258
+
259
class MobileViTIntermediate(nn.Module):
    """Linear layer from `hidden_size` to `intermediate_size` followed by the configured activation."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        # `config.hidden_act` is either the name of an activation or the activation itself.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer, then the activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
272
+
273
+
274
class MobileViTOutput(nn.Module):
    """Contraction half of the feed-forward block: linear layer, dropout, then a residual addition."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.dense = nn.Linear(in_features=intermediate_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` back to `hidden_size`, apply dropout and add `input_tensor` as the residual."""
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
285
+
286
+
287
class MobileViTTransformerLayer(nn.Module):
    """Pre-norm transformer layer: LayerNorm -> attention -> residual, then LayerNorm -> MLP -> residual."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        norm_eps = config.layer_norm_eps

        self.attention = MobileViTAttention(config, hidden_size)
        self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
        self.output = MobileViTOutput(config, hidden_size, intermediate_size)
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the attention sub-block and the feed-forward sub-block, each with its own residual connection."""
        # Attention sub-block: normalise first, add the un-normalised input back afterwards.
        normed_input = self.layernorm_before(hidden_states)
        hidden_states = self.attention(normed_input) + hidden_states

        # Feed-forward sub-block: `self.output` performs the residual addition itself.
        feed_forward = self.intermediate(self.layernorm_after(hidden_states))
        return self.output(feed_forward, hidden_states)
304
+
305
+
306
class MobileViTTransformer(nn.Module):
    """A stack of `num_stages` identical `MobileViTTransformerLayer` modules applied in sequence."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
        super().__init__()

        # The MLP width of each layer is the hidden size scaled by `config.mlp_ratio`, truncated to an int.
        self.layer = nn.ModuleList(
            [
                MobileViTTransformerLayer(
                    config,
                    hidden_size=hidden_size,
                    intermediate_size=int(hidden_size * config.mlp_ratio),
                )
                for _ in range(num_stages)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Feed `hidden_states` through every layer in order and return the final result."""
        for block in self.layer:
            hidden_states = block(hidden_states)
        return hidden_states
323
+
324
+
325
class MobileViTLayer(GradientCheckpointingLayer):
    """
    MobileViT block: https://huggingface.co/papers/2110.02178

    Optionally downsamples the input, encodes local structure with convolutions, runs a transformer over the
    unfolded patches, folds the result back into a feature map and fuses it with the pre-transformer features.
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int,
        hidden_size: int,
        num_stages: int,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        # Patches are square: both sides are taken from `config.patch_size`.
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        if stride != 2:
            self.downsampling_layer = None
        else:
            # With dilation > 1 the stride is dropped to 1 and half of the dilation is passed on instead.
            downsample_stride = stride if dilation == 1 else 1
            downsample_dilation = dilation // 2 if dilation > 1 else 1
            self.downsampling_layer = MobileViTInvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=downsample_stride,
                dilation=downsample_dilation,
            )
            # Everything after the downsampling layer works on its output channels.
            in_channels = out_channels

        self.conv_kxk = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
        )

        self.conv_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        self.transformer = MobileViTTransformer(
            config,
            hidden_size=hidden_size,
            num_stages=num_stages,
        )

        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        self.conv_projection = MobileViTConvLayer(
            config,
            in_channels=hidden_size,
            out_channels=in_channels,
            kernel_size=1,
        )

        # The fusion conv receives the residual concatenated with the projected features (2 * in_channels).
        self.fusion = MobileViTConvLayer(
            config,
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
        )

    def unfolding(self, features: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """
        Convert a `(batch_size, channels, height, width)` feature map into patches of shape
        `(batch_size * patch_area, num_patches, channels)`.

        The map is first resized bilinearly when its spatial size is not a multiple of the patch size. Returns the
        patches together with a dict holding everything `folding` needs to invert the operation.
        """
        patch_height = self.patch_height
        patch_width = self.patch_width
        patch_area = int(patch_width * patch_height)

        batch_size, channels, orig_height, orig_width = features.shape

        # Round both spatial sizes up to the next multiple of the patch size.
        if torch.jit.is_tracing():
            new_height = torch_int(torch.ceil(orig_height / patch_height) * patch_height)
            new_width = torch_int(torch.ceil(orig_width / patch_width) * patch_width)
        else:
            new_height = int(math.ceil(orig_height / patch_height) * patch_height)
            new_width = int(math.ceil(orig_width / patch_width) * patch_width)

        interpolate = False
        if new_width != orig_width or new_height != orig_height:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            features = nn.functional.interpolate(
                features, size=(new_height, new_width), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_width = new_width // patch_width
        num_patch_height = new_height // patch_height
        num_patches = num_patch_height * num_patch_width

        # convert from shape (batch_size, channels, orig_height, orig_width)
        # to the shape (batch_size * patch_area, num_patches, channels)
        grid = features.reshape(batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
        grid = grid.transpose(1, 2)
        grid = grid.reshape(batch_size, channels, num_patches, patch_area)
        grid = grid.transpose(1, 3)
        patches = grid.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_height, orig_width),
            "batch_size": batch_size,
            "channels": channels,
            "interpolate": interpolate,
            "num_patches": num_patches,
            "num_patches_width": num_patch_width,
            "num_patches_height": num_patch_height,
        }
        return patches, info_dict

    def folding(self, patches: torch.Tensor, info_dict: dict) -> torch.Tensor:
        """
        Inverse of `unfolding`: rebuild a `(batch_size, channels, height, width)` feature map from patches, and
        resize it back to the original spatial size when `unfolding` had to interpolate.
        """
        patch_height = self.patch_height
        patch_width = self.patch_width
        patch_area = int(patch_width * patch_height)

        batch_size = info_dict["batch_size"]
        channels = info_dict["channels"]
        num_patches = info_dict["num_patches"]
        num_patch_height = info_dict["num_patches_height"]
        num_patch_width = info_dict["num_patches_width"]

        # convert from shape (batch_size * patch_area, num_patches, channels)
        # back to shape (batch_size, channels, orig_height, orig_width)
        grid = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
        grid = grid.transpose(1, 3)
        grid = grid.reshape(batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
        grid = grid.transpose(1, 2)
        features = grid.reshape(
            batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
        )

        if info_dict["interpolate"]:
            features = nn.functional.interpolate(
                features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
            )

        return features

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Run the full block on `features` and return the fused feature map."""
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        # Kept aside for the fusion step at the end.
        residual = features

        # local representation
        local_features = self.conv_1x1(self.conv_kxk(features))

        # convert feature map to patches
        patches, info_dict = self.unfolding(local_features)

        # learn global representations
        patches = self.layernorm(self.transformer(patches))

        # convert patches back to feature maps
        global_features = self.folding(patches, info_dict)
        global_features = self.conv_projection(global_features)

        # Concatenate along the channel axis and fuse.
        return self.fusion(torch.cat((residual, global_features), dim=1))
492
+
493
+
494
class MobileViTEncoder(nn.Module):
    """Five-stage encoder: two `MobileViTMobileNetLayer` stages followed by three `MobileViTLayer` stages."""

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()
        self.config = config

        self.layer = nn.ModuleList()
        self.gradient_checkpointing = False

        # segmentation architectures like DeepLab and PSPNet modify the strides
        # of the classification backbones
        output_stride = config.output_stride
        dilate_layer_4 = output_stride == 8
        dilate_layer_5 = output_stride in (8, 16)

        # The dilation accumulates: each dilated stage doubles whatever the previous stage used.
        dilation_4 = 2 if dilate_layer_4 else 1
        dilation_5 = dilation_4 * 2 if dilate_layer_5 else dilation_4

        channels = config.neck_hidden_sizes
        widths = config.hidden_sizes

        # The list literal is evaluated left to right, so the stages are built in order 1..5.
        self.layer.extend(
            [
                MobileViTMobileNetLayer(
                    config,
                    in_channels=channels[0],
                    out_channels=channels[1],
                    stride=1,
                    num_stages=1,
                ),
                MobileViTMobileNetLayer(
                    config,
                    in_channels=channels[1],
                    out_channels=channels[2],
                    stride=2,
                    num_stages=3,
                ),
                MobileViTLayer(
                    config,
                    in_channels=channels[2],
                    out_channels=channels[3],
                    stride=2,
                    hidden_size=widths[0],
                    num_stages=2,
                ),
                MobileViTLayer(
                    config,
                    in_channels=channels[3],
                    out_channels=channels[4],
                    stride=2,
                    hidden_size=widths[1],
                    num_stages=4,
                    dilation=dilation_4,
                ),
                MobileViTLayer(
                    config,
                    in_channels=channels[4],
                    out_channels=channels[5],
                    stride=2,
                    hidden_size=widths[2],
                    num_stages=3,
                    dilation=dilation_5,
                ),
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> tuple | BaseModelOutputWithNoAttention:
        """
        Run every stage in order.

        When `output_hidden_states` is set, the output of each stage is collected into a tuple. Returns a plain tuple
        (with `None` entries dropped) when `return_dict` is false, otherwise a `BaseModelOutputWithNoAttention`.
        """
        per_stage_outputs = () if output_hidden_states else None

        for stage in self.layer:
            hidden_states = stage(hidden_states)
            if output_hidden_states:
                per_stage_outputs = (*per_stage_outputs, hidden_states)

        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=per_stage_outputs,
            )

        return tuple(value for value in (hidden_states, per_stage_outputs) if value is not None)
587
+
588
+
589
@auto_docstring
class MobileViTPreTrainedModel(PreTrainedModel):
    # Common base for the MobileViT models: class-level settings plus the weight-initialisation hook.
    config: MobileViTConfig
    base_model_prefix = "mobilevit"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    supports_gradient_checkpointing = True
    _no_split_modules = ["MobileViTLayer"]

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights"""
        # LayerNorm: zero bias, unit weight.
        if isinstance(module, nn.LayerNorm):
            init.zeros_(module.bias)
            init.ones_(module.weight)
            return

        # Any other module type that is not handled below is left untouched.
        if not isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
            return

        init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            init.zeros_(module.bias)

        # Only modules that carry running statistics have these buffers reset.
        if getattr(module, "running_mean", None) is not None:
            init.zeros_(module.running_mean)
            init.ones_(module.running_var)
            init.zeros_(module.num_batches_tracked)
612
+
613
+
614
@auto_docstring
class MobileViTModel(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig, expand_output: bool = True):
        r"""
        expand_output (`bool`, *optional*, defaults to `True`):
            Whether to expand the output of the model using a 1x1 convolution. If `True`, the model will apply an additional
            1x1 convolution to expand the output channels from `config.neck_hidden_sizes[5]` to `config.neck_hidden_sizes[6]`.
        """
        super().__init__(config)
        self.config = config
        self.expand_output = expand_output

        neck_sizes = config.neck_hidden_sizes

        # Stem: a stride-2 3x3 convolution from the image channels to the first neck width.
        self.conv_stem = MobileViTConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=neck_sizes[0],
            kernel_size=3,
            stride=2,
        )

        self.encoder = MobileViTEncoder(config)

        # The expansion conv only exists when requested, so `forward` must check the same flag before using it.
        if self.expand_output:
            self.conv_1x1_exp = MobileViTConvLayer(
                config,
                in_channels=neck_sizes[5],
                out_channels=neck_sizes[6],
                kernel_size=1,
            )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPoolingAndNoAttention:
        # Unset flags fall back to the values stored on the config.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        stem_features = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            stem_features,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = None
        if self.expand_output:
            last_hidden_state = self.conv_1x1_exp(last_hidden_state)

            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)

        if return_dict:
            return BaseModelOutputWithPoolingAndNoAttention(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
            )

        # Tuple output: the pooled entry is omitted entirely when there is no pooling.
        head = (last_hidden_state,) if pooled_output is None else (last_hidden_state, pooled_output)
        return head + encoder_outputs[1:]
689
+
690
+
691
@auto_docstring(
    custom_intro="""
    MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """
)
class MobileViTForImageClassification(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevit = MobileViTModel(config)

        # Classifier head
        self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
        # Without labels to predict, the "classifier" simply passes the pooled features through.
        if config.num_labels > 0:
            self.classifier = nn.Linear(config.neck_hidden_sizes[-1], config.num_labels)
        else:
            self.classifier = nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | ImageClassifierOutputWithNoAttention:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.return_dict

        outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # The pooled features sit at index 1 of the tuple-style backbone output.
        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        logits = self.classifier(self.dropout(pooled_output))

        loss = None if labels is None else self.loss_function(labels, logits, self.config)

        if return_dict:
            return ImageClassifierOutputWithNoAttention(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
            )

        # Tuple output: logits followed by whatever the backbone returned after the pooled features.
        output = (logits,) + outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
749
+
750
+
751
class MobileViTASPPPooling(nn.Module):
    """Image-level ASPP branch: global average pool, 1x1 convolution, then resize back to the input size."""

    def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
        super().__init__()

        # Collapses each channel to a single 1x1 spatial value.
        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.conv_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return the pooled-and-projected features, bilinearly resized to the spatial size of `features`."""
        target_size = features.shape[-2:]
        pooled = self.conv_1x1(self.global_pool(features))
        return nn.functional.interpolate(pooled, size=target_size, mode="bilinear", align_corners=False)
773
+
774
+
775
class MobileViTASPP(nn.Module):
    """
    ASPP module defined in DeepLab papers: https://huggingface.co/papers/1606.00915, https://huggingface.co/papers/1706.05587
    """

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()

        in_channels = config.neck_hidden_sizes[-2]
        out_channels = config.aspp_out_channels

        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        # Branch 1: plain 1x1 projection.
        in_projection = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
        )

        # Branches 2-4: one dilated 3x3 convolution per atrous rate.
        atrous_branches = [
            MobileViTConvLayer(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                dilation=rate,
                use_activation="relu",
            )
            for rate in config.atrous_rates
        ]

        # Branch 5: image-level pooling.
        pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)

        self.convs = nn.ModuleList([in_projection, *atrous_branches, pool_layer])

        # The five branch outputs are concatenated channel-wise before this projection.
        self.project = MobileViTConvLayer(
            config,
            in_channels=5 * out_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
        )

        self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply every branch to `features`, concatenate along channels, then project and apply dropout."""
        pyramid = torch.cat([branch(features) for branch in self.convs], dim=1)
        return self.dropout(self.project(pyramid))
832
+
833
+
834
class MobileViTDeepLabV3(nn.Module):
    """
    DeepLabv3 architecture: https://huggingface.co/papers/1706.05587
    """

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()
        self.aspp = MobileViTASPP(config)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        # Final 1x1 convolution to `config.num_labels` channels: no normalization, no activation, with bias.
        self.classifier = MobileViTConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ASPP, dropout and the classifier on the last entry of `hidden_states`."""
        # Only the final element of `hidden_states` is used.
        last_stage = hidden_states[-1]
        pooled = self.dropout(self.aspp(last_stage))
        return self.classifier(pooled)
860
+
861
+
862
@auto_docstring(
    custom_intro="""
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """
)
class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # The backbone is built without the 1x1 expansion, so it produces no pooled output.
        self.mobilevit = MobileViTModel(config, expand_output=False)
        self.segmentation_head = MobileViTDeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SemanticSegmenterOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import httpx
        >>> from io import BytesIO
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> with httpx.stream("GET", url) as response:
        ...     image = Image.open(BytesIO(response.read()))

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.mobilevit(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        # Without a pooled output, the hidden states sit at index 1 of the tuple-style backbone output.
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            criterion = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
            loss = criterion(upsampled_logits, labels)

        if return_dict:
            return SemanticSegmenterOutput(
                loss=loss,
                logits=logits,
                # Hidden states were always computed; expose them only if the caller asked for them.
                hidden_states=outputs.hidden_states if output_hidden_states else None,
                attentions=None,
            )

        # Tuple output: drop the hidden states again unless the caller asked for them.
        trailing = outputs[1:] if output_hidden_states else outputs[2:]
        output = (logits,) + trailing
        if loss is None:
            return output
        return (loss,) + output
956
+
957
+
958
# Names exported by `from <module> import *`.
__all__ = [
    "MobileViTForImageClassification",
    "MobileViTForSemanticSegmentation",
    "MobileViTModel",
    "MobileViTPreTrainedModel",
]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/configuration_speecht5.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """SpeechT5 model configuration"""
15
+
16
+ import functools
17
+ import operator
18
+
19
+ from huggingface_hub.dataclasses import strict
20
+
21
+ from ...configuration_utils import PreTrainedConfig
22
+ from ...utils import auto_docstring
23
+
24
+
25
+ @auto_docstring(checkpoint="microsoft/speecht5_asr")
26
+ @strict
27
+ class SpeechT5Config(PreTrainedConfig):
28
+ r"""
29
+ positional_dropout (`float`, *optional*, defaults to 0.1):
30
+ The dropout probability for the text position encoding layers.
31
+ feat_extract_norm (`str`, *optional*, defaults to `"group"`):
32
+ The norm to be applied to 1D convolutional layers in the speech encoder pre-net. One of `"group"` for group
33
+ normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
34
+ convolutional layers.
35
+ feat_proj_dropout (`float`, *optional*, defaults to 0.0):
36
+ The dropout probability for output of the speech encoder pre-net.
37
+ feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
38
+ The non-linear activation function (function or string) in the 1D convolutional layers of the feature
39
+ extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
40
+ conv_dim (`tuple[int]` or `list[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
41
+ A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
42
+ speech encoder pre-net. The length of *conv_dim* defines the number of 1D convolutional layers.
43
+ conv_stride (`tuple[int]` or `list[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
44
+ A tuple of integers defining the stride of each 1D convolutional layer in the speech encoder pre-net. The
45
+ length of *conv_stride* defines the number of convolutional layers and has to match the length of
46
+ *conv_dim*.
47
+ conv_kernel (`tuple[int]` or `list[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
48
+ A tuple of integers defining the kernel size of each 1D convolutional layer in the speech encoder pre-net.
49
+ The length of *conv_kernel* defines the number of convolutional layers and has to match the length of
50
+ *conv_dim*.
51
+ conv_bias (`bool`, *optional*, defaults to `False`):
52
+ Whether the 1D convolutional layers have a bias.
53
+ num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
54
+ Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
55
+ embeddings layer.
56
+ num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
57
+ Number of groups of 1D convolutional positional embeddings layer.
58
+ apply_spec_augment (`bool`, *optional*, defaults to `True`):
59
+ Whether to apply *SpecAugment* data augmentation to the outputs of the speech encoder pre-net. For
60
+ reference see [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
61
+ Recognition](https://huggingface.co/papers/1904.08779).
62
+ mask_time_prob (`float`, *optional*, defaults to 0.05):
63
+ Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
64
+ procedure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If
65
+ reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
66
+ masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
67
+ actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
68
+ mask_time_length (`int`, *optional*, defaults to 10):
69
+ Length of vector span along the time axis.
70
+ mask_time_min_masks (`int`, *optional*, defaults to 2):
71
+ The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
72
+ irrespectively of `mask_time_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length <
73
+ mask_time_min_masks''
74
+ mask_feature_prob (`float`, *optional*, defaults to 0.0):
75
+ Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
76
+ masking procedure generates ''mask_feature_prob*len(feature_axis)/mask_feature_length'' independent masks over
77
+ the axis. If reasoning from the probability of each feature vector to be chosen as the start of the vector
78
+ span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap
79
+ may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is
80
+ True`.
81
+ mask_feature_length (`int`, *optional*, defaults to 10):
82
+ Length of vector span along the feature axis.
83
+ mask_feature_min_masks (`int`, *optional*, defaults to 0):
84
+ The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
85
+ step, irrespectively of `mask_feature_prob`. Only relevant if
86
+ ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks''
87
+ num_mel_bins (`int`, *optional*, defaults to 80):
88
+ Number of mel features used per input features. Used by the speech decoder pre-net. Should correspond to
89
+ the value used in the [`SpeechT5Processor`] class.
90
+ speech_decoder_prenet_layers (`int`, *optional*, defaults to 2):
91
+ Number of layers in the speech decoder pre-net.
92
+ speech_decoder_prenet_units (`int`, *optional*, defaults to 256):
93
+ Dimensionality of the layers in the speech decoder pre-net.
94
+ speech_decoder_prenet_dropout (`float`, *optional*, defaults to 0.5):
95
+ The dropout probability for the speech decoder pre-net layers.
96
+ speaker_embedding_dim (`int`, *optional*, defaults to 512):
97
+ Dimensionality of the *XVector* embedding vectors.
98
+ speech_decoder_postnet_layers (`int`, *optional*, defaults to 5):
99
+ Number of layers in the speech decoder post-net.
100
+ speech_decoder_postnet_units (`int`, *optional*, defaults to 256):
101
+ Dimensionality of the layers in the speech decoder post-net.
102
+ speech_decoder_postnet_kernel (`int`, *optional*, defaults to 5):
103
+ Kernel size of the convolutional layers in the speech decoder post-net.
104
+ speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
105
+ The dropout probability for the speech decoder post-net layers.
106
+ reduction_factor (`int`, *optional*, defaults to 2):
107
+ Spectrogram length reduction factor for the speech decoder inputs.
108
+ max_speech_positions (`int`, *optional*, defaults to 4000):
109
+ The maximum sequence length of speech features that this model might ever be used with.
110
+ max_text_positions (`int`, *optional*, defaults to 450):
111
+ The maximum sequence length of text features that this model might ever be used with.
112
+ encoder_max_relative_position (`int`, *optional*, defaults to 160):
113
+ Maximum distance for relative position embedding in the encoder.
114
+ use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
115
+ Whether to apply guided attention loss while training the TTS model.
116
+ guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
117
+ Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
118
+ attention heads.
119
+ guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
120
+ Standard deviation for guided attention loss.
121
+ guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
122
+ Scaling coefficient for guided attention loss (also known as lambda).
123
+
124
+ Example:
125
+
126
+ ```python
127
+ >>> from transformers import SpeechT5Model, SpeechT5Config
128
+
129
+ >>> # Initializing a "microsoft/speecht5_asr" style configuration
130
+ >>> configuration = SpeechT5Config()
131
+
132
+ >>> # Initializing a model (with random weights) from the "microsoft/speecht5_asr" style configuration
133
+ >>> model = SpeechT5Model(configuration)
134
+
135
+ >>> # Accessing the model configuration
136
+ >>> configuration = model.config
137
+ ```"""
138
+
139
+ model_type = "speecht5"
140
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "num_hidden_layers": "encoder_layers"}
141
+
142
+ vocab_size: int = 81
143
+ hidden_size: int = 768
144
+ encoder_layers: int = 12
145
+ encoder_attention_heads: int = 12
146
+ encoder_ffn_dim: int = 3072
147
+ encoder_layerdrop: float | int = 0.1
148
+ decoder_layers: int = 6
149
+ decoder_ffn_dim: int = 3072
150
+ decoder_attention_heads: int = 12
151
+ decoder_layerdrop: float | int = 0.1
152
+ hidden_act: str = "gelu"
153
+ positional_dropout: float | int = 0.1
154
+ hidden_dropout: float | int = 0.1
155
+ attention_dropout: float | int = 0.1
156
+ activation_dropout: float | int = 0.1
157
+ initializer_range: float = 0.02
158
+ layer_norm_eps: float = 1e-5
159
+ scale_embedding: bool = False
160
+ feat_extract_norm: str = "group"
161
+ feat_proj_dropout: float | int = 0.0
162
+ feat_extract_activation: str = "gelu"
163
+ conv_dim: list[int] | tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512)
164
+ conv_stride: list[int] | tuple[int, ...] = (5, 2, 2, 2, 2, 2, 2)
165
+ conv_kernel: list[int] | tuple[int, ...] = (10, 3, 3, 3, 3, 2, 2)
166
+ conv_bias: bool = False
167
+ num_conv_pos_embeddings: int = 128
168
+ num_conv_pos_embedding_groups: int = 16
169
+ apply_spec_augment: bool = True
170
+ mask_time_prob: float | int = 0.05
171
+ mask_time_length: int = 10
172
+ mask_time_min_masks: int = 2
173
+ mask_feature_prob: float | int = 0.0
174
+ mask_feature_length: int = 10
175
+ mask_feature_min_masks: int = 0
176
+ pad_token_id: int | None = 1
177
+ bos_token_id: int | None = 0
178
+ eos_token_id: int | list[int] | None = 2
179
+ decoder_start_token_id: int | None = 2
180
+ num_mel_bins: int = 80
181
+ speech_decoder_prenet_layers: int = 2
182
+ speech_decoder_prenet_units: int = 256
183
+ speech_decoder_prenet_dropout: float | int = 0.5
184
+ speaker_embedding_dim: int = 512
185
+ speech_decoder_postnet_layers: int = 5
186
+ speech_decoder_postnet_units: int = 256
187
+ speech_decoder_postnet_kernel: int = 5
188
+ speech_decoder_postnet_dropout: float | int = 0.5
189
+ reduction_factor: int = 2
190
+ max_speech_positions: int = 4000
191
+ max_text_positions: int = 450
192
+ encoder_max_relative_position: int = 160
193
+ use_guided_attention_loss: bool = True
194
+ guided_attention_loss_num_heads: int = 2
195
+ guided_attention_loss_sigma: float = 0.4
196
+ guided_attention_loss_scale: float = 10.0
197
+ use_cache: bool = True
198
+ is_encoder_decoder: bool = True
199
+ tie_word_embeddings: bool = True
200
+
201
def __post_init__(self, **kwargs):
    """Derive the feature-extractor depth from `conv_dim`, then run the parent post-init."""
    # One feature-extractor layer per entry of `conv_dim`; this must be set before the
    # parent hook runs, as in the original statement order.
    layer_count = len(self.conv_dim)
    self.num_feat_extract_layers = layer_count
    super().__post_init__(**kwargs)
204
+
205
def validate_architecture(self):
    """Check that the convolutional settings all describe the same number of layers.

    Raises:
        ValueError: If `conv_stride`, `conv_kernel` or `conv_dim` does not have exactly
            `num_feat_extract_layers` entries.
    """
    expected = self.num_feat_extract_layers
    n_dim = len(self.conv_dim)
    n_stride = len(self.conv_stride)
    n_kernel = len(self.conv_kernel)

    # Guard clause: nothing to report when every sequence has the expected length.
    if n_stride == expected and n_kernel == expected and n_dim == expected:
        return

    raise ValueError(
        "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
        " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
        f" {n_dim}`, `len(config.conv_stride) = {n_stride}`,"
        f" `len(config.conv_kernel) = {n_kernel}`."
    )
218
+
219
def inputs_to_logits_ratio(self):
    """Return the product of all entries of `conv_stride` (1 when it is empty)."""
    ratio = 1
    for stride in self.conv_stride:
        ratio = ratio * stride
    return ratio
221
+
222
+
223
# Configuration holder for the SpeechT5 HiFi-GAN vocoder. The class only declares typed
# fields with defaults; no methods are defined here.
# NOTE(review): the decorator references the "microsoft/speecht5_asr" checkpoint while the
# usage example below mentions "microsoft/speecht5_hifigan" — confirm which one is intended.
@auto_docstring(checkpoint="microsoft/speecht5_asr")
@strict
class SpeechT5HifiGanConfig(PreTrainedConfig):
    r"""
    model_in_dim (`int`, *optional*, defaults to 80):
        The number of frequency bins in the input log-mel spectrogram.
    upsample_initial_channel (`int`, *optional*, defaults to 512):
        The number of input channels into the upsampling network.
    upsample_rates (`tuple[int]` or `list[int]`, *optional*, defaults to `[4, 4, 4, 4]`):
        A tuple of integers defining the stride of each 1D convolutional layer in the upsampling network. The
        length of *upsample_rates* defines the number of convolutional layers and has to match the length of
        *upsample_kernel_sizes*.
    upsample_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[8, 8, 8, 8]`):
        A tuple of integers defining the kernel size of each 1D convolutional layer in the upsampling network. The
        length of *upsample_kernel_sizes* defines the number of convolutional layers and has to match the length of
        *upsample_rates*.
    resblock_kernel_sizes (`tuple[int]` or `list[int]`, *optional*, defaults to `[3, 7, 11]`):
        A tuple of integers defining the kernel sizes of the 1D convolutional layers in the multi-receptive field
        fusion (MRF) module.
    resblock_dilation_sizes (`tuple[tuple[int]]` or `list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
        A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the
        multi-receptive field fusion (MRF) module.
    leaky_relu_slope (`float`, *optional*, defaults to 0.1):
        The angle of the negative slope used by the leaky ReLU activation.
    normalize_before (`bool`, *optional*, defaults to `True`):
        Whether or not to normalize the spectrogram before vocoding using the vocoder's learned mean and variance.

    Example:

    ```python
    >>> from transformers import SpeechT5HifiGan, SpeechT5HifiGanConfig

    >>> # Initializing a "microsoft/speecht5_hifigan" style configuration
    >>> configuration = SpeechT5HifiGanConfig()

    >>> # Initializing a model (with random weights) from the "microsoft/speecht5_hifigan" style configuration
    >>> model = SpeechT5HifiGan(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "speecht5_hifigan"

    # Width of the input spectrogram (number of frequency bins, per the docstring).
    model_in_dim: int = 80
    # Not described in the docstring above.
    # NOTE(review): presumably the audio sampling rate in Hz — confirm.
    sampling_rate: int = 16000
    upsample_initial_channel: int = 512
    # Per the docstring, `upsample_rates` and `upsample_kernel_sizes` must have the same
    # length; no validation of that is visible in this class.
    upsample_rates: list[int] | tuple[int, ...] = (4, 4, 4, 4)
    upsample_kernel_sizes: list[int] | tuple[int, ...] = (8, 8, 8, 8)
    resblock_kernel_sizes: list[int] | tuple[int, ...] = (3, 7, 11)
    # Nested sequence; the annotation is deliberately loose (`list | tuple`) and does not
    # constrain the inner element type.
    resblock_dilation_sizes: list | tuple = ((1, 3, 5), (1, 3, 5), (1, 3, 5))
    # Not described in the docstring above.
    # NOTE(review): presumably the std used for weight initialisation — confirm.
    initializer_range: float = 0.01
    leaky_relu_slope: float = 0.1
    normalize_before: bool = True
277
+
278
+
279
+ __all__ = ["SpeechT5Config", "SpeechT5HifiGanConfig"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/modeling_speecht5.py ADDED
The diff for this file is too large to render. See raw diff
 
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/number_normalizer.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Number Normalizer class for SpeechT5."""
15
+
16
+ import re
17
+
18
+
19
class EnglishNumberNormalizer:
    """Spell out numbers, currency amounts and percentages found in English text.

    Calling an instance on a string replaces every number-like token (optionally signed,
    optionally prefixed by a currency symbol, optionally followed by up to two decimals
    and/or a percent sign) with its spelt-out form, e.g. ``"$15"`` -> ``"fifteen dollars"``.
    """

    # Digit runs joined by one or more commas ("15,000", "1,000,000"). The whole run is
    # matched at once so that every separator is removed, not just the first one.
    _COMMA_GROUPED_NUMBER = re.compile(r"\d+(?:,\d+)+")
    # Optional sign, optional currency symbols, digits, optional 1-2 decimals, optional "%".
    _NUMBER = re.compile(
        r"(?<!\w)(-?\$?\€?\£?\¢?\¥?\₹?\₽?\฿?\₺?\₴?\₣?\₡?\₱?\₪?\₮?\₩?\₦?\₫?\﷼?\d+(?:\.\d{1,2})?%?)(?!\w)"
    )
    _MULTIPLE_SPACES = re.compile(" +")

    def __init__(self):
        self.ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
        # Index 0 is a placeholder: "ten" is produced through `self.tens`, so the lookup
        # `self.teens[value - 10]` is only ever used for 11..19.
        self.teens = [
            "",
            "eleven",
            "twelve",
            "thirteen",
            "fourteen",
            "fifteen",
            "sixteen",
            "seventeen",
            "eighteen",
            "nineteen",
        ]
        self.tens = ["", "ten", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
        # Scale word for each successive group of three digits, least significant first.
        self.thousands = [
            "",
            "thousand",
            "million",
            "billion",
            "trillion",
            "quadrillion",
            "quintillion",
            "sextillion",
            "septillion",
            "octillion",
            "nonillion",
            "decillion",
        ]

        # Define a dictionary to map currency symbols to their names
        # Top most traded currencies according to
        # https://en.wikipedia.org/wiki/Template:Most_traded_currencies
        self.currency_symbols = {
            "$": " dollars",
            "€": " euros",
            "£": " pounds",
            "¢": " cents",
            "¥": " japanese yen",
            "﷼": " saudi riyal",
            "₹": " indian rupees",
            "₽": " russian rubles",
            "฿": " thai baht",
            "₺": " turkish liras",
            "₴": " ukrainian hryvnia",
            "₣": " swiss francs",
            "₡": " costa rican colon",
            "₱": " philippine peso",
            "₪": " israeli shekels",
            "₮": " mongolian tögrög",
            "₩": " south korean won",
            "₦": " nigerian naira",
            "₫": " vietnamese Đồng",
        }

    def spell_number(self, num):
        """
        Spell out a non-negative integer.

        Args:
            num (`int`): The value to spell. Values below 1000 behave exactly as before;
                larger values now receive their scale words ("thousand", "million", ...).

        Returns:
            `str`: The spelt-out number, e.g. `"one hundred and twenty"`; `"zero"` for 0.
        """
        if num == 0:
            return "zero"

        parts = []
        for scale in self.thousands:
            if num == 0:
                break

            group = num % 1000
            if group:
                hundreds, tens_units = divmod(group, 100)
                part = ""

                if hundreds > 0:
                    part += self.ones[hundreds] + " hundred"
                    if tens_units > 0:
                        part += " and "

                if 10 < tens_units < 20:
                    part += self.teens[tens_units - 10]
                else:
                    tens_word = self.tens[tens_units // 10]
                    ones_word = self.ones[tens_units % 10]
                    part += " ".join(word for word in (tens_word, ones_word) if word)

                # Previously the scale word was never attached, so any value >= 1000
                # was spelt without "thousand", "million", etc.
                if scale:
                    part += " " + scale

                parts.append(part)

            num //= 1000

        return " ".join(reversed(parts))

    def convert(self, number):
        """
        Converts an individual number passed in string form to spelt-out form

        Args:
            number (`str`): A single number token, optionally carrying a leading `-` (or
                `minus`), a leading currency symbol, a decimal part and/or a `%` sign.

        Returns:
            `str`: The spelt-out form, e.g. `"minus three"`, `"fifty percent"`,
            `"fifteen dollars"` or `"zero point five"`.
        """
        if "." in number:
            integer_part, decimal_part = number.split(".")
        else:
            # "00" is the sentinel for "no decimal part to read out".
            integer_part, decimal_part = number, "00"

        # Extract currency symbol if present (either first, or right after a minus sign)
        currency_symbol = ""
        for symbol, name in self.currency_symbols.items():
            if integer_part.startswith(symbol):
                currency_symbol = name
                integer_part = integer_part[len(symbol) :]
                break

            if integer_part.startswith("-"):
                if integer_part[1:].startswith(symbol):
                    currency_symbol = name
                    integer_part = "-" + integer_part[len(symbol) + 1 :]
                    break

        # Extract 'minus' prefix for negative numbers
        minus_prefix = ""
        if integer_part.startswith("-"):
            minus_prefix = "minus "
            integer_part = integer_part[1:]
        elif integer_part.startswith("minus"):
            minus_prefix = "minus "
            integer_part = integer_part[len("minus") :]

        percent_suffix = ""
        if "%" in integer_part or "%" in decimal_part:
            percent_suffix = " percent"
            integer_part = integer_part.replace("%", "")
            decimal_part = decimal_part.replace("%", "")

        if len(integer_part) > 3 * len(self.thousands):
            # More digit groups than known scale words: read the digits one by one
            # instead of failing with an IndexError on the scale lookup.
            spelled_integer = " ".join(self.spell_number(int(digit)) for digit in integer_part)
        else:
            # Left-pad so the digits split evenly into groups of three.
            integer_part = integer_part.zfill(3 * ((len(integer_part) - 1) // 3 + 1))

            parts = []
            for i in range(0, len(integer_part), 3):
                chunk = int(integer_part[i : i + 3])
                if chunk > 0:
                    part = self.spell_number(chunk)
                    unit = self.thousands[len(integer_part[i:]) // 3 - 1]
                    if unit:
                        part += " " + unit
                    parts.append(part)

            # An all-zero integer part used to vanish entirely ("0" -> "", "0.5" -> " point five").
            spelled_integer = " ".join(parts) if parts else "zero"

        # The prefix/suffix pieces are empty strings when absent, so a single template
        # covers every combination of minus sign, percent sign and currency.
        if decimal_part == "00":
            return f"{minus_prefix}{spelled_integer}{percent_suffix}{currency_symbol}"

        # Decimals are read digit by digit.
        spelled_decimal = " ".join([self.spell_number(int(digit)) for digit in decimal_part])
        return f"{minus_prefix}{spelled_integer} point {spelled_decimal}{percent_suffix}{currency_symbol}"

    def __call__(self, text):
        """
        Convert numbers / number-like quantities in a string to their spelt-out counterparts

        Args:
            text (`str`): Arbitrary input text.

        Returns:
            `str`: The text with numbers replaced and runs of spaces collapsed to one space.
        """
        # Find and replace commas in numbers (15,000 -> 15000, 1,000,000 -> 1000000, etc)
        text = self._COMMA_GROUPED_NUMBER.sub(lambda match: match.group(0).replace(",", ""), text)

        # Use regex to find and replace numbers in the text
        converted_text = self._NUMBER.sub(lambda match: self.convert(match.group(1)), text)
        converted_text = self._MULTIPLE_SPACES.sub(" ", converted_text)

        return converted_text
LTA_openwebtext_dualt/mini_owt_logdirichlet/.venv_qwen35_uv/lib/python3.12/site-packages/transformers/models/speecht5/tokenization_speecht5.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenization class for SpeechT5."""
15
+
16
+ from typing import Any
17
+
18
+ from ...tokenization_utils_sentencepiece import SentencePieceBackend
19
+ from ...utils import logging
20
+ from ...utils.import_utils import requires
21
+ from .number_normalizer import EnglishNumberNormalizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"}
27
+
28
+
29
@requires(backends=("sentencepiece",))
class SpeechT5Tokenizer(SentencePieceBackend):
    """
    SpeechT5 tokenizer built on [SentencePiece](https://github.com/google/sentencepiece).

    The class derives from `SentencePieceBackend`, which supplies most of the tokenizer methods; see that
    superclass for details on them.

    Args:
        vocab_file (`str`):
            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
            contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The begin of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        normalize (`bool`, *optional*, defaults to `False`):
            Whether to convert numeric quantities in the text to their spelt-out english counterparts.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

    Attributes:
        sp_model (`SentencePieceProcessor`):
            The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]
    is_fast = False

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        normalize=False,
        sp_model_kwargs: dict[str, Any] | None = None,
        **kwargs,
    ) -> None:
        # These attributes are assigned before the parent constructor runs.
        self.normalize = normalize
        self._normalizer = None

        # Forward the SentencePiece options through kwargs only when they were supplied.
        if sp_model_kwargs is not None:
            kwargs["sp_model_kwargs"] = sp_model_kwargs

        # The parent constructor is responsible for loading the SentencePiece model.
        super().__init__(
            vocab_file=vocab_file,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            normalize=normalize,
            **kwargs,
        )

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """Optionally prefix a space and spell out numbers; returns `(text, remaining kwargs)`."""
        # A per-call `normalize` keyword overrides the instance-level setting.
        should_normalize = kwargs.pop("normalize", self.normalize)
        prepared = " " + text if is_split_into_words else text
        if should_normalize:
            prepared = self.normalizer(prepared)
        return (prepared, kwargs)

    @property
    def normalizer(self):
        """The number normalizer, created on first access."""
        if self._normalizer is None:
            self._normalizer = EnglishNumberNormalizer()
        return self._normalizer

    @normalizer.setter
    def normalizer(self, value):
        self._normalizer = value

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
        """Return the input ids (one sequence or a concatenated pair) followed by `eos_token_id`."""
        # Pairs are not expected in practice; the pair path is kept for API consistency.
        sequence = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
        return sequence + [self.eos_token_id]

    def get_special_tokens_mask(
        self, token_ids_0: list[int], token_ids_1: list[int] | None = None, already_has_special_tokens: bool = False
    ) -> list[int]:
        """Return a mask with 0 for every sequence token and 1 for the trailing special token."""
        if not already_has_special_tokens:
            regular_count = len(token_ids_0)
            if token_ids_1 is not None:
                regular_count += len(token_ids_1)
            # Only the final position (the appended end-of-sequence token) is special.
            return [0] * regular_count + [1]

        return super().get_special_tokens_mask(
            token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
        )

    def create_token_type_ids_from_sequences(
        self, token_ids_0: list[int], token_ids_1: list[int] | None = None
    ) -> list[int]:
        """
        Build token type ids for one sequence or a sequence pair. SpeechT5 does not use token type ids, so the
        result contains only zeros: one per input token plus one for the end-of-sequence token.

        Args:
            token_ids_0 (`list[int]`):
                List of IDs.
            token_ids_1 (`list[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `list[int]`: List of zeros.
        """
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            combined = token_ids_0 + eos
        else:
            combined = token_ids_0 + token_ids_1 + eos
        return [0] * len(combined)
164
+
165
+
166
+ __all__ = ["SpeechT5Tokenizer"]
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_053000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3a0ce15a3f8e0441fca84965ff658de402bc494bd94b53730430287ab2ab2df
3
+ size 927700322
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_163000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c731b03de97e220aab47f37b3d6c191aa23591d0c4f973a4b38e93730358b2bf
3
+ size 927700322
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_172000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa1869d751dd1db17309c432203e4d0978a22ab1fbe9065646c5d04cfe9baa67
3
+ size 927700322
LTA_openwebtext_dualt/mini_owt_logdirichlet/runs/owt_t5_elftokenized_full_len1024_C1_to_1024_pow1_d768_l12_h12_gbs512_2x8gpu_50ep_lr4e3_ema0p9999_elfopt_not5_bottleneck128_unfixed_norm_stateprobadd_selfcond_ce_fast_20260610_020108/step_182000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ca8affe97a9e4ab92c98e52693f27b329e99dd9122eb9b7672ab56618aaf840
3
+ size 927700322